[Not for land] Added changes for GPT-2 perf
Stack from ghstack (oldest at bottom):
- -> #533
- #532
Credit: @felipemello1 for the previous token-chunked cross entropy.
Credit: @Chillee for the new token-chunked cross entropy.
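For context, here is a minimal sketch of what a token-chunked cross entropy can look like (the function name, `num_chunks`, and shapes are illustrative assumptions, not the exact implementation in this PR):

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, num_chunks: int = 16) -> torch.Tensor:
    # Split along the sequence dim so only one chunk of float32 logits is live at a time,
    # instead of upcasting the full (batch, seq, vocab) activation in one shot.
    losses = []
    for logit_chunk, label_chunk in zip(
        logits.chunk(num_chunks, dim=1), labels.chunk(num_chunks, dim=1)
    ):
        losses.append(
            F.cross_entropy(
                logit_chunk.float().flatten(0, 1),  # (batch * chunk_len, vocab)
                label_chunk.flatten(0, 1),          # (batch * chunk_len,)
                reduction="sum",
            )
        )
    return torch.stack(losses).sum() / labels.numel()
```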
Running on 4xH100s:
Without these changes (torch.compile), the max local batch size is 5:
[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step: 1 loss: 12.2365 memory: 81.67GiB(85.93%) wps: 5,380 mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10 loss: 12.1951 memory: 81.67GiB(85.93%) wps: 111,770 mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20 loss: 11.9455 memory: 81.67GiB(85.93%) wps: 111,714 mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30 loss: 11.0407 memory: 81.67GiB(85.93%) wps: 112,194 mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40 loss: 9.9520 memory: 81.67GiB(85.93%) wps: 112,109 mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50 loss: 9.3392 memory: 81.67GiB(85.93%) wps: 112,218 mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60 loss: 8.7255 memory: 81.67GiB(85.93%) wps: 112,198 mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70 loss: 8.1659 memory: 81.67GiB(85.93%) wps: 112,234 mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80 loss: 7.8037 memory: 81.67GiB(85.93%) wps: 111,802 mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90 loss: 7.5327 memory: 81.67GiB(85.93%) wps: 111,937 mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100 loss: 7.3730 memory: 81.67GiB(85.93%) wps: 111,803 mfu: 22.69%
Without these changes (no torch.compile), local batch size 5:
[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step: 1 loss: 12.2581 memory: 86.47GiB(90.99%) wps: 6,393 mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10 loss: 12.2099 memory: 86.48GiB(90.99%) wps: 98,305 mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20 loss: 11.9421 memory: 86.48GiB(90.99%) wps: 98,230 mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30 loss: 11.0090 memory: 86.48GiB(90.99%) wps: 98,435 mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40 loss: 9.9780 memory: 86.48GiB(90.99%) wps: 99,064 mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50 loss: 9.3572 memory: 86.48GiB(90.99%) wps: 98,813 mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60 loss: 8.7479 memory: 86.48GiB(90.99%) wps: 96,567 mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70 loss: 8.1769 memory: 86.48GiB(90.99%) wps: 98,604 mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80 loss: 7.8070 memory: 86.48GiB(90.99%) wps: 98,579 mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90 loss: 7.5329 memory: 86.48GiB(90.99%) wps: 98,743 mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100 loss: 7.3700 memory: 86.48GiB(90.99%) wps: 98,818 mfu: 20.05%
With these changes (torch.compile), local batch size 32:
[rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2024-09-06 19:49:08,904 - root - INFO - step: 1 loss: 12.2442 memory: 79.40GiB(83.54%) wps: 24,819 mfu: 5.04%
[rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10 loss: 12.1998 memory: 80.81GiB(85.03%) wps: 165,880 mfu: 33.66%
[rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20 loss: 11.9284 memory: 80.81GiB(85.03%) wps: 165,732 mfu: 33.63%
[rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30 loss: 10.9587 memory: 80.81GiB(85.03%) wps: 165,733 mfu: 33.63%
[rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40 loss: 9.8493 memory: 80.81GiB(85.03%) wps: 165,904 mfu: 33.66%
[rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50 loss: 9.2317 memory: 80.81GiB(85.03%) wps: 159,786 mfu: 32.42%
Old Results
With these changes, we can use local batch size 16:
[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step: 1 loss: 12.2386 memory: 72.29GiB(76.06%) wps: 21,887 mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10 loss: 12.1966 memory: 72.30GiB(76.07%) wps: 168,174 mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20 loss: 11.9229 memory: 72.30GiB(76.07%) wps: 168,196 mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30 loss: 10.9399 memory: 72.30GiB(76.07%) wps: 168,144 mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40 loss: 9.8742 memory: 72.30GiB(76.07%) wps: 167,898 mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50 loss: 9.2517 memory: 72.30GiB(76.07%) wps: 168,130 mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60 loss: 8.6441 memory: 72.30GiB(76.07%) wps: 168,435 mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70 loss: 8.0827 memory: 72.30GiB(76.07%) wps: 168,927 mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80 loss: 7.7330 memory: 72.30GiB(76.07%) wps: 168,772 mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90 loss: 7.4835 memory: 72.30GiB(76.07%) wps: 162,008 mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100 loss: 7.3274 memory: 72.30GiB(76.07%) wps: 167,963 mfu: 34.08%
22.7% MFU -> 34.1% MFU
Great results! Consider using .split(8192//32, dim=1) instead of .chunk(16).
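(For reference, `.chunk` takes a chunk count while `.split` takes a chunk size, so the two calls give different granularities; `logits` below is a hypothetical `(batch, 8192, vocab)` tensor.)

```python
chunks = logits.chunk(16, dim=1)          # 16 chunks of 512 tokens each
chunks = logits.split(8192 // 32, dim=1)  # 32 chunks of 256 tokens each
```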
There was a huge difference in reserved memory in my experiments (no change in active memory though)
FYI, compiling the loss + model together should yield much better results than compiling the model alone, if that is what's happening.
Instead of doing torch.compile(model), do something like:

@torch.compile()
def loss_step(input, label):
    output = model(input)                 # model forward
    loss = calculate_loss(output, label)  # e.g. cross entropy on the logits
    loss.backward()
    return loss
What we found is that using torch.compile on the cross entropy loss alone has great memory benefits (but not better than chunked): https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/
Compiling model + loss together almost doubles toks/second: https://github.com/pytorch/torchtune/issues/1228#issuecomment-2277151232
But the best results for us come from compiling only the model and using the chunked cross entropy. If we compile everything, the benefits of the chunked cross entropy are lost.
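A rough sketch of that setup, assuming the `chunked_cross_entropy` helper from the sketch above and illustrative `dataloader`/`optimizer` names: the compiled region covers only the model (torchtitan compiles each TransformerBlock), while the chunked loss stays in eager.

```python
# Compile only the model; keep the chunked cross entropy in eager so the chunking
# actually bounds peak memory rather than being undone by whole-graph compilation.
model = torch.compile(model)

for inputs, labels in dataloader:
    optimizer.zero_grad()
    logits = model(inputs)                        # compiled forward
    loss = chunked_cross_entropy(logits, labels)  # eager chunked loss (see sketch above)
    loss.backward()
    optimizer.step()
```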
If I try to compile both the output linear and cross entropy loss together instead of just compiling the cross entropy loss, I get OOMs at the same batch size.
My uneducated guess is that the optimizations they made for CrossEntropyLoss account only for the loss being compiled on its own. Details of their implementation here: https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/
Llama3-8B
With these changes:
[rank0]:2024-08-21 08:44:32,865 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:44:32,897 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:44:32,953 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:44:45,742 - root - INFO - GPU memory usage for model: 3.99GiB(4.20%)
[rank0]:2024-08-21 08:44:45,743 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-21 08:46:07,756 - root - INFO - step: 1 loss: 12.2044 memory: 71.90GiB(75.65%) wps: 100 mfu: 0.58%
[rank0]:2024-08-21 08:46:07,756 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:46:19,266 - root - INFO - step: 10 loss: 10.8650 memory: 82.42GiB(86.72%) wps: 6,405 mfu: 37.51%
[rank0]:2024-08-21 08:46:32,000 - root - INFO - step: 20 loss: 9.1536 memory: 82.42GiB(86.72%) wps: 6,434 mfu: 37.68%
[rank0]:2024-08-21 08:46:44,771 - root - INFO - step: 30 loss: 8.1057 memory: 82.42GiB(86.72%) wps: 6,416 mfu: 37.57%
Baseline:
[rank0]:2024-08-21 08:47:49,505 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:47:49,829 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:47:49,892 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:48:01,630 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-21 08:48:01,631 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-21 08:48:56,359 - root - INFO - step: 1 loss: 12.2556 memory: 67.73GiB(71.26%) wps: 150 mfu: 0.88%
[rank0]:2024-08-21 08:48:56,359 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:49:08,017 - root - INFO - step: 10 loss: 10.8891 memory: 71.64GiB(75.38%) wps: 6,324 mfu: 37.03%
[rank0]:2024-08-21 08:49:20,989 - root - INFO - step: 20 loss: 9.0440 memory: 71.64GiB(75.38%) wps: 6,316 mfu: 36.99%
[rank0]:2024-08-21 08:49:33,984 - root - INFO - step: 30 loss: 8.0371 memory: 71.64GiB(75.38%) wps: 6,305 mfu: 36.92%
[rank0]:2024-08-21 08:49:46,985 - root - INFO - step: 40 loss: 7.4550 memory: 71.64GiB(75.38%) wps: 6,302 mfu: 36.90%
[rank0]:2024-08-21 08:50:00,009 - root - INFO - step: 50 loss: 7.2264 memory: 71.64GiB(75.38%) wps: 6,290 mfu: 36.84%
Moving .float() into the CE loss and compiling it (P1539164756):
[rank0]:2024-08-21 08:55:16,471 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:55:16,786 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:55:16,847 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:55:29,108 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-21 08:55:29,110 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:2024-08-21 08:56:21,884 - root - INFO - step: 1 loss: 12.2164 memory: 58.18GiB(61.22%) wps: 155 mfu: 0.91%
[rank0]:2024-08-21 08:56:21,884 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:56:33,351 - root - INFO - step: 10 loss: 10.8179 memory: 66.01GiB(69.46%) wps: 6,430 mfu: 37.65%
[rank0]:2024-08-21 08:56:46,101 - root - INFO - step: 20 loss: 9.0846 memory: 66.01GiB(69.46%) wps: 6,426 mfu: 37.63%
[rank0]:2024-08-21 08:56:58,879 - root - INFO - step: 30 loss: 8.0600 memory: 66.01GiB(69.46%) wps: 6,412 mfu: 37.55%
[rank0]:2024-08-21 08:57:11,658 - root - INFO - step: 40 loss: 7.4393 memory: 66.01GiB(69.46%) wps: 6,411 mfu: 37.54%
[rank0]:2024-08-21 08:57:24,460 - root - INFO - step: 50 loss: 7.1899 memory: 66.01GiB(69.46%) wps: 6,400 mfu: 37.48%
[rank0]:2024-08-21 08:57:37,291 - root - INFO - step: 60 loss: 7.0205 memory: 66.01GiB(69.46%) wps: 6,386 mfu: 37.39%
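For reference, the "moving .float() into the CE loss and compiling it" variant could look roughly like the sketch below (illustrative names; the actual change is in the linked paste):

```python
import torch
import torch.nn.functional as F

@torch.compile
def compiled_cross_entropy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Upcast to float32 inside the compiled region so Inductor can fuse the cast
    # into the loss computation instead of materializing full float32 logits.
    return F.cross_entropy(logits.float().flatten(0, 1), labels.flatten(0, 1))
```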