
[POC] Showed more memory-efficient FSDP wrapping

awgu opened this issue 1 year ago

Stack from ghstack (oldest at bottom):

  • #459
  • -> #382

This requires https://github.com/pytorch/pytorch/pull/127786.
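For context, the general memory argument for finer-grained FSDP wrapping can be sketched with a toy calculation. All sizes below are made-up assumptions (not Llama3-8B's actual layout), and this is the generic per-block-wrapping argument, not the specific regrouping this PR makes:

```python
# Toy estimate of peak unsharded parameter memory under two FSDP
# wrapping granularities. All sizes are illustrative assumptions.

world_size = 8            # 8xH100
n_blocks = 32             # assumed number of transformer blocks
block_gib = 0.42          # assumed parameters per block, in GiB
other_gib = 2.0           # assumed embeddings / final norm / output, in GiB

total_gib = n_blocks * block_gib + other_gib
shard_gib = total_gib / world_size  # each rank's persistent shard

# One FSDP group for the whole model: everything is all-gathered at once.
peak_coarse = shard_gib + total_gib

# One FSDP group per transformer block (plus a root group): only one
# block needs to be unsharded at a time during forward/backward.
peak_per_block = shard_gib + block_gib + other_gib

print(f"coarse wrapping peak:    {peak_coarse:.2f} GiB")
print(f"per-block wrapping peak: {peak_per_block:.2f} GiB")
```

The per-block peak scales with the largest single group rather than the full model, which is why how parameters are grouped into FSDP modules matters for reserved memory.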

Experiments

  • Llama3-8B on 8xH100, 1D FSDP, local batch size 2, selective op AC, compiled_rmsnorm, torch.compile enabled per transformer block, fused AdamW
    • With this PR (68.09 GiB reserved memory):
    [rank0]:2024-07-11 10:55:21,533 - root - INFO - step:  1  loss: 12.2308  memory: 60.27GiB(63.41%)  wps: 233  mfu: 1.37%
    [rank0]:2024-07-11 10:55:21,534 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 10:55:24,063 - root - INFO - step:  2  loss: 12.0520  memory: 68.09GiB(71.65%)  wps: 6,479  mfu: 37.94%
    [rank0]:2024-07-11 10:55:26,596 - root - INFO - step:  3  loss: 11.7165  memory: 68.09GiB(71.65%)  wps: 6,470  mfu: 37.89%
    [rank0]:2024-07-11 10:55:29,139 - root - INFO - step:  4  loss: 11.3078  memory: 68.09GiB(71.65%)  wps: 6,445  mfu: 37.74%
    [rank0]:2024-07-11 10:55:31,681 - root - INFO - step:  5  loss: 10.8763  memory: 68.09GiB(71.65%)  wps: 6,446  mfu: 37.75%
    
    • Without this PR (69.04 GiB reserved memory):
    [rank0]:2024-07-11 11:03:35,749 - root - INFO - step:  1  loss: 12.2646  memory: 61.21GiB(64.41%)  wps: 305  mfu: 1.79%
    [rank0]:2024-07-11 11:03:35,749 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 11:03:38,284 - root - INFO - step:  2  loss: 12.0713  memory: 69.04GiB(72.65%)  wps: 6,464  mfu: 37.85%
    [rank0]:2024-07-11 11:03:40,821 - root - INFO - step:  3  loss: 11.7398  memory: 69.04GiB(72.65%)  wps: 6,460  mfu: 37.83%
    [rank0]:2024-07-11 11:03:43,356 - root - INFO - step:  4  loss: 11.3238  memory: 69.04GiB(72.65%)  wps: 6,462  mfu: 37.84%
    [rank0]:2024-07-11 11:03:45,898 - root - INFO - step:  5  loss: 10.9178  memory: 69.04GiB(72.65%)  wps: 6,448  mfu: 37.76%
    
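Taking the steady-state numbers (steps 2-5) from the logs above, the selective-AC run's reserved-memory difference works out to:

```python
# Steady-state reserved memory (GiB) from the selective-AC logs above.
with_pr = 68.09
without_pr = 69.04

saving_gib = without_pr - with_pr
saving_pct = 100 * saving_gib / without_pr
print(f"saving: {saving_gib:.2f} GiB ({saving_pct:.1f}% of baseline)")
```

i.e. roughly a 1 GiB reduction at essentially unchanged wps/mfu.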
  • Llama3-8B on 8xH100, 1D FSDP, local batch size 1, no AC, compiled_rmsnorm, torch.compile enabled per transformer block, fused AdamW
    • With this PR (68.36 GiB reserved memory):
    [rank0]:2024-07-11 12:53:24,747 - root - INFO - step:  1  loss: 12.2439  memory: 58.58GiB(61.63%)  wps: 148  mfu: 0.87%
    [rank0]:2024-07-11 12:53:24,750 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 12:53:26,042 - root - INFO - step:  2  loss: 12.0557  memory: 68.36GiB(71.93%)  wps: 6,342  mfu: 37.14%
    [rank0]:2024-07-11 12:53:27,338 - root - INFO - step:  3  loss: 11.7423  memory: 68.36GiB(71.93%)  wps: 6,324  mfu: 37.03%
    [rank0]:2024-07-11 12:53:28,630 - root - INFO - step:  4  loss: 11.3138  memory: 68.36GiB(71.93%)  wps: 6,343  mfu: 37.15%
    [rank0]:2024-07-11 12:53:29,927 - root - INFO - step:  5  loss: 10.9011  memory: 68.36GiB(71.93%)  wps: 6,319  mfu: 37.00%
    
    • Without this PR (67.50 GiB reserved memory):
    [rank0]:2024-07-11 12:50:09,792 - root - INFO - step:  1  loss: 12.2539  memory: 63.58GiB(66.90%)  wps: 146  mfu: 0.86%
    [rank0]:2024-07-11 12:50:09,792 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
    [rank0]:2024-07-11 12:50:11,087 - root - INFO - step:  2  loss: 12.0905  memory: 67.50GiB(71.02%)  wps: 6,328  mfu: 37.06%
    [rank0]:2024-07-11 12:50:12,385 - root - INFO - step:  3  loss: 11.7652  memory: 67.50GiB(71.02%)  wps: 6,314  mfu: 36.97%
    [rank0]:2024-07-11 12:50:13,680 - root - INFO - step:  4  loss: 11.2644  memory: 67.50GiB(71.02%)  wps: 6,327  mfu: 37.05%
    [rank0]:2024-07-11 12:50:14,978 - root - INFO - step:  5  loss: 10.8718  memory: 67.50GiB(71.02%)  wps: 6,315  mfu: 36.98%
    
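The same comparison for the no-AC run goes in the other direction:

```python
# Steady-state reserved memory (GiB) from the no-AC logs above.
with_pr = 68.36
without_pr = 67.50

regression_gib = with_pr - without_pr
print(f"regression: {regression_gib:.2f} GiB more with this PR")
```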

For some reason, without AC the new wrapping actually uses more memory (68.36 GiB vs. 67.50 GiB reserved at steady state). This could be due to memory fragmentation or torch.compile effects and needs more investigation.

awgu · Jun 03 '24 19:06