torchtitan
[POC] Show more memory-efficient FSDP wrapping
Stack from ghstack (oldest at bottom):
- #459
- -> #382
This requires https://github.com/pytorch/pytorch/pull/127786.
Experiments
- Llama3-8B on 8xH100, 1D FSDP, local batch size 2, selective op AC, compiled_rmsnorm, torch.compile enabled per transformer block, fused AdamW
  - With this PR (68.09 GiB reserved memory):
    [rank0]:2024-07-11 10:55:21,533 - root - INFO - step: 1 loss: 12.2308 memory: 60.27GiB(63.41%) wps: 233 mfu: 1.37%
    [rank0]:2024-07-11 10:55:21,534 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 10:55:24,063 - root - INFO - step: 2 loss: 12.0520 memory: 68.09GiB(71.65%) wps: 6,479 mfu: 37.94%
    [rank0]:2024-07-11 10:55:26,596 - root - INFO - step: 3 loss: 11.7165 memory: 68.09GiB(71.65%) wps: 6,470 mfu: 37.89%
    [rank0]:2024-07-11 10:55:29,139 - root - INFO - step: 4 loss: 11.3078 memory: 68.09GiB(71.65%) wps: 6,445 mfu: 37.74%
    [rank0]:2024-07-11 10:55:31,681 - root - INFO - step: 5 loss: 10.8763 memory: 68.09GiB(71.65%) wps: 6,446 mfu: 37.75%
  - Without this PR (69.04 GiB reserved memory):
    [rank0]:2024-07-11 11:03:35,749 - root - INFO - step: 1 loss: 12.2646 memory: 61.21GiB(64.41%) wps: 305 mfu: 1.79%
    [rank0]:2024-07-11 11:03:35,749 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 11:03:38,284 - root - INFO - step: 2 loss: 12.0713 memory: 69.04GiB(72.65%) wps: 6,464 mfu: 37.85%
    [rank0]:2024-07-11 11:03:40,821 - root - INFO - step: 3 loss: 11.7398 memory: 69.04GiB(72.65%) wps: 6,460 mfu: 37.83%
    [rank0]:2024-07-11 11:03:43,356 - root - INFO - step: 4 loss: 11.3238 memory: 69.04GiB(72.65%) wps: 6,462 mfu: 37.84%
    [rank0]:2024-07-11 11:03:45,898 - root - INFO - step: 5 loss: 10.9178 memory: 69.04GiB(72.65%) wps: 6,448 mfu: 37.76%
- Llama3-8B on 8xH100, 1D FSDP, local batch size 1, no AC, compiled_rmsnorm, torch.compile enabled per transformer block, fused AdamW
  - With this PR (68.36 GiB reserved memory):
    [rank0]:2024-07-11 12:53:24,747 - root - INFO - step: 1 loss: 12.2439 memory: 58.58GiB(61.63%) wps: 148 mfu: 0.87%
    [rank0]:2024-07-11 12:53:24,750 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 12:53:26,042 - root - INFO - step: 2 loss: 12.0557 memory: 68.36GiB(71.93%) wps: 6,342 mfu: 37.14%
    [rank0]:2024-07-11 12:53:27,338 - root - INFO - step: 3 loss: 11.7423 memory: 68.36GiB(71.93%) wps: 6,324 mfu: 37.03%
    [rank0]:2024-07-11 12:53:28,630 - root - INFO - step: 4 loss: 11.3138 memory: 68.36GiB(71.93%) wps: 6,343 mfu: 37.15%
    [rank0]:2024-07-11 12:53:29,927 - root - INFO - step: 5 loss: 10.9011 memory: 68.36GiB(71.93%) wps: 6,319 mfu: 37.00%
  - Without this PR (67.50 GiB reserved memory):
    [rank0]:2024-07-11 12:50:09,792 - root - INFO - step: 1 loss: 12.2539 memory: 63.58GiB(66.90%) wps: 146 mfu: 0.86%
    [rank0]:2024-07-11 12:50:09,792 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 12:50:11,087 - root - INFO - step: 2 loss: 12.0905 memory: 67.50GiB(71.02%) wps: 6,328 mfu: 37.06%
    [rank0]:2024-07-11 12:50:12,385 - root - INFO - step: 3 loss: 11.7652 memory: 67.50GiB(71.02%) wps: 6,314 mfu: 36.97%
    [rank0]:2024-07-11 12:50:13,680 - root - INFO - step: 4 loss: 11.2644 memory: 67.50GiB(71.02%) wps: 6,327 mfu: 37.05%
    [rank0]:2024-07-11 12:50:14,978 - root - INFO - step: 5 loss: 10.8718 memory: 67.50GiB(71.02%) wps: 6,315 mfu: 36.98%
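To make these with/without comparisons easier to redo, here is a hypothetical helper sketch (the names `summarize` and `compare` are illustrative, not torchtitan APIs). It assumes the per-step log format shown above and reads the `memory:` field, which the run headings above describe as reserved memory. It treats step 1 as warmup, since its wps and mfu are likely dominated by torch.compile and other first-step overhead.

```python
import re
from statistics import mean

# Assumed log format, matching the per-step lines above, e.g.
# "... step: 2 loss: 12.0520 memory: 68.09GiB(71.65%) wps: 6,479 mfu: 37.94%"
STEP_RE = re.compile(
    r"step:\s*(?P<step>\d+)\s+loss:\s*(?P<loss>[\d.]+)\s+"
    r"memory:\s*(?P<mem>[\d.]+)GiB\((?P<pct>[\d.]+)%\)\s+"
    r"wps:\s*(?P<wps>[\d,]+)\s+mfu:\s*(?P<mfu>[\d.]+)%"
)


def summarize(log_text: str, skip_steps: int = 1) -> dict:
    """Summarize step lines, ignoring the first `skip_steps` (warmup) steps."""
    rows = [m for m in STEP_RE.finditer(log_text) if int(m["step"]) > skip_steps]
    if not rows:
        raise ValueError("no step lines found after the warmup steps")
    return {
        "steps": len(rows),
        # Largest reported memory over the measured steps, in GiB.
        "memory_gib": max(float(r["mem"]) for r in rows),
        # wps is printed with thousands separators, e.g. "6,479".
        "mean_wps": mean(int(r["wps"].replace(",", "")) for r in rows),
        "mean_mfu": mean(float(r["mfu"]) for r in rows),
    }


def compare(with_pr: str, without_pr: str) -> dict:
    """Return (with PR) minus (without PR) for memory and mean throughput."""
    a, b = summarize(with_pr), summarize(without_pr)
    return {
        "memory_delta_gib": round(a["memory_gib"] - b["memory_gib"], 2),
        "wps_delta": round(a["mean_wps"] - b["mean_wps"], 1),
    }
```

For example, passing the two selective-AC logs above to `compare` gives a memory delta of -0.95 GiB (68.09 vs. 69.04 GiB) and a mean-wps delta of +1.5 over steps 2 to 5. The same call on the no-AC logs gives +0.86 GiB (68.36 vs. 67.50 GiB).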
Without AC, the new wrapping actually uses more reserved memory (68.36 GiB vs. 67.50 GiB), and it is not yet clear why. This could be due to memory fragmentation or to an interaction with torch.compile, and it needs more investigation.
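One way to check the fragmentation hypothesis is to compare the CUDA caching allocator's statistics between the two no-AC runs. The sketch below assumes a dict shaped like the one returned by `torch.cuda.memory_stats()`, using the keys `reserved_bytes.all.peak`, `active_bytes.all.peak`, `inactive_split_bytes.all.current` (documented as inactive, non-releasable memory), and `num_alloc_retries`. The function name `fragmentation_report` is hypothetical. Because it takes the stats dict as input, it can run without a GPU.

```python
GIB = 1024 ** 3


def fragmentation_report(stats: dict) -> dict:
    """Summarize a torch.cuda.memory_stats()-style dict (assumed key names) in GiB."""

    def gib(key: str) -> float:
        # Missing keys are treated as 0 so partial dicts still work.
        return round(stats.get(key, 0) / GIB, 2)

    reserved = gib("reserved_bytes.all.peak")
    active = gib("active_bytes.all.peak")
    return {
        "peak_reserved_gib": reserved,
        "peak_active_gib": active,
        # Memory held by the allocator beyond what live blocks needed at their peak.
        # The two peaks may occur at different times, so this is only a rough signal.
        "reserved_minus_active_gib": round(reserved - active, 2),
        # Inactive, non-releasable memory inside split blocks: a common fragmentation symptom.
        "inactive_split_gib": gib("inactive_split_bytes.all.current"),
        # cudaMalloc failures that forced a cache flush and retry.
        "num_alloc_retries": stats.get("num_alloc_retries", 0),
    }
```

In practice, one would call `fragmentation_report(torch.cuda.memory_stats())` on rank 0 after step 5 in both runs. If this PR shows a similar peak active value but a larger reserved-minus-active gap or inactive-split total, that would point toward fragmentation rather than more live tensors.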