
[Not for land] Added changes for GPT-2 perf

Open awgu opened this issue 1 year ago • 5 comments

Stack from ghstack (oldest at bottom):

  • -> #533
  • #532

Credit: @felipemello1 for the previous token chunked cross entropy

Credit: @Chillee for the new token chunked cross entropy

Running on 4xH100s:

Without these changes (torch.compile), the max local batch size is 5:

[rank0]:2024-08-19 11:10:26,196 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 11:10:33,811 - root - INFO - step:  1  loss: 12.2365  memory: 81.67GiB(85.93%)  wps: 5,380  mfu: 1.09%
[rank0]:2024-08-19 11:10:33,811 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:10:37,109 - root - INFO - step: 10  loss: 12.1951  memory: 81.67GiB(85.93%)  wps: 111,770  mfu: 22.68%
[rank0]:2024-08-19 11:10:40,777 - root - INFO - step: 20  loss: 11.9455  memory: 81.67GiB(85.93%)  wps: 111,714  mfu: 22.67%
[rank0]:2024-08-19 11:10:44,428 - root - INFO - step: 30  loss: 11.0407  memory: 81.67GiB(85.93%)  wps: 112,194  mfu: 22.76%
[rank0]:2024-08-19 11:10:48,083 - root - INFO - step: 40  loss:  9.9520  memory: 81.67GiB(85.93%)  wps: 112,109  mfu: 22.75%
[rank0]:2024-08-19 11:10:51,734 - root - INFO - step: 50  loss:  9.3392  memory: 81.67GiB(85.93%)  wps: 112,218  mfu: 22.77%
[rank0]:2024-08-19 11:10:55,386 - root - INFO - step: 60  loss:  8.7255  memory: 81.67GiB(85.93%)  wps: 112,198  mfu: 22.77%
[rank0]:2024-08-19 11:10:59,037 - root - INFO - step: 70  loss:  8.1659  memory: 81.67GiB(85.93%)  wps: 112,234  mfu: 22.77%
[rank0]:2024-08-19 11:11:02,701 - root - INFO - step: 80  loss:  7.8037  memory: 81.67GiB(85.93%)  wps: 111,802  mfu: 22.68%
[rank0]:2024-08-19 11:11:06,361 - root - INFO - step: 90  loss:  7.5327  memory: 81.67GiB(85.93%)  wps: 111,937  mfu: 22.71%
[rank0]:2024-08-19 11:11:10,026 - root - INFO - step: 100  loss:  7.3730  memory: 81.67GiB(85.93%)  wps: 111,803  mfu: 22.69%

Without these changes (no torch.compile), local batch size 5:

[rank0]:2024-08-19 14:24:32,150 - root - INFO - Training starts at step 1, with local batch size 5, global batch size 20, sequence length 8192, total steps 100 (warmup 200)
[rank0]:2024-08-19 14:24:38,558 - root - INFO - step:  1  loss: 12.2581  memory: 86.47GiB(90.99%)  wps: 6,393  mfu: 1.30%
[rank0]:2024-08-19 14:24:38,558 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 14:24:42,308 - root - INFO - step: 10  loss: 12.2099  memory: 86.48GiB(90.99%)  wps: 98,305  mfu: 19.95%
[rank0]:2024-08-19 14:24:46,482 - root - INFO - step: 20  loss: 11.9421  memory: 86.48GiB(90.99%)  wps: 98,230  mfu: 19.93%
[rank0]:2024-08-19 14:24:50,648 - root - INFO - step: 30  loss: 11.0090  memory: 86.48GiB(90.99%)  wps: 98,435  mfu: 19.97%
[rank0]:2024-08-19 14:24:54,788 - root - INFO - step: 40  loss:  9.9780  memory: 86.48GiB(90.99%)  wps: 99,064  mfu: 20.10%
[rank0]:2024-08-19 14:24:58,936 - root - INFO - step: 50  loss:  9.3572  memory: 86.48GiB(90.99%)  wps: 98,813  mfu: 20.05%
[rank0]:2024-08-19 14:25:03,181 - root - INFO - step: 60  loss:  8.7479  memory: 86.48GiB(90.99%)  wps: 96,567  mfu: 19.59%
[rank0]:2024-08-19 14:25:07,339 - root - INFO - step: 70  loss:  8.1769  memory: 86.48GiB(90.99%)  wps: 98,604  mfu: 20.01%
[rank0]:2024-08-19 14:25:11,497 - root - INFO - step: 80  loss:  7.8070  memory: 86.48GiB(90.99%)  wps: 98,579  mfu: 20.00%
[rank0]:2024-08-19 14:25:15,649 - root - INFO - step: 90  loss:  7.5329  memory: 86.48GiB(90.99%)  wps: 98,743  mfu: 20.04%
[rank0]:2024-08-19 14:25:19,798 - root - INFO - step: 100  loss:  7.3700  memory: 86.48GiB(90.99%)  wps: 98,818  mfu: 20.05%

With these changes (torch.compile), local batch size 32:

[rank0]:2024-09-06 19:48:58,342 - root - INFO - Training starts at step 1, with local batch size 32, global batch size 128, sequence length 8192, total steps 50 (warmup 200)
[rank0]:2024-09-06 19:49:08,904 - root - INFO - step:  1  loss: 12.2442  memory: 79.40GiB(83.54%)  wps: 24,819  mfu: 5.04%
[rank0]:2024-09-06 19:49:08,904 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-09-06 19:49:23,127 - root - INFO - step: 10  loss: 12.1998  memory: 80.81GiB(85.03%)  wps: 165,880  mfu: 33.66%
[rank0]:2024-09-06 19:49:38,946 - root - INFO - step: 20  loss: 11.9284  memory: 80.81GiB(85.03%)  wps: 165,732  mfu: 33.63%
[rank0]:2024-09-06 19:49:54,764 - root - INFO - step: 30  loss: 10.9587  memory: 80.81GiB(85.03%)  wps: 165,733  mfu: 33.63%
[rank0]:2024-09-06 19:50:10,566 - root - INFO - step: 40  loss:  9.8493  memory: 80.81GiB(85.03%)  wps: 165,904  mfu: 33.66%
[rank0]:2024-09-06 19:50:26,973 - root - INFO - step: 50  loss:  9.2317  memory: 80.81GiB(85.03%)  wps: 159,786  mfu: 32.42%
Old Results

With these changes, we can use local batch size 16:

[rank0]:2024-08-19 11:16:09,534 - root - INFO - Training starts at step 1, with local batch size 16, global batch size 64, sequence length 8192, total steps 100 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-19 11:16:15,523 - root - INFO - step:  1  loss: 12.2386  memory: 72.29GiB(76.06%)  wps: 21,887  mfu: 4.44%
[rank0]:2024-08-19 11:16:15,523 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-19 11:16:22,538 - root - INFO - step: 10  loss: 12.1966  memory: 72.30GiB(76.07%)  wps: 168,174  mfu: 34.12%
[rank0]:2024-08-19 11:16:30,332 - root - INFO - step: 20  loss: 11.9229  memory: 72.30GiB(76.07%)  wps: 168,196  mfu: 34.13%
[rank0]:2024-08-19 11:16:38,129 - root - INFO - step: 30  loss: 10.9399  memory: 72.30GiB(76.07%)  wps: 168,144  mfu: 34.12%
[rank0]:2024-08-19 11:16:45,937 - root - INFO - step: 40  loss:  9.8742  memory: 72.30GiB(76.07%)  wps: 167,898  mfu: 34.07%
[rank0]:2024-08-19 11:16:53,734 - root - INFO - step: 50  loss:  9.2517  memory: 72.30GiB(76.07%)  wps: 168,130  mfu: 34.11%
[rank0]:2024-08-19 11:17:01,518 - root - INFO - step: 60  loss:  8.6441  memory: 72.30GiB(76.07%)  wps: 168,435  mfu: 34.18%
[rank0]:2024-08-19 11:17:09,279 - root - INFO - step: 70  loss:  8.0827  memory: 72.30GiB(76.07%)  wps: 168,927  mfu: 34.28%
[rank0]:2024-08-19 11:17:17,047 - root - INFO - step: 80  loss:  7.7330  memory: 72.30GiB(76.07%)  wps: 168,772  mfu: 34.24%
[rank0]:2024-08-19 11:17:25,139 - root - INFO - step: 90  loss:  7.4835  memory: 72.30GiB(76.07%)  wps: 162,008  mfu: 32.87%
[rank0]:2024-08-19 11:17:32,944 - root - INFO - step: 100  loss:  7.3274  memory: 72.30GiB(76.07%)  wps: 167,963  mfu: 34.08%

22.7% MFU -> 34.1% MFU
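
As a rough cross-check of the summary line above, the step-20 values from the logs can be compared directly. This is only a sketch: the numbers are copied from the logs in this description, and treating `wps` as tokens per second per GPU is an assumption on my part, not something the logs state.

```python
# Values copied from the step-20 lines of the logs above.
baseline = {"wps": 111_714, "mfu": 22.67}  # torch.compile, local batch size 5
chunked = {"wps": 168_196, "mfu": 34.13}   # old results, local batch size 16

wps_ratio = chunked["wps"] / baseline["wps"]
mfu_ratio = chunked["mfu"] / baseline["mfu"]
print(f"wps ratio: {wps_ratio:.3f}, mfu ratio: {mfu_ratio:.3f}")

# Assumption: wps is per GPU, so 10 steps at global batch size 20 and
# sequence length 8192 on 4 GPUs should take roughly this many seconds.
tokens_per_10_steps = 10 * 20 * 8192
expected_seconds = tokens_per_10_steps / (4 * baseline["wps"])
print(f"expected time for 10 steps: {expected_seconds:.2f}s")
```

Both ratios come out at about 1.5x, and the expected time is close to the 3.668 s between the step-10 and step-20 timestamps in the baseline log, which is consistent with the per-GPU reading of `wps`.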

awgu avatar Aug 19 '24 18:08 awgu

Great results! Consider using .split(8192//32, dim=1) instead of .chunk(16).

In my experiments there was a huge difference in reserved memory (though no change in active memory).
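
To make the difference between the two calls concrete, here is a small plain-Python sketch of the piece sizes each one yields. The helpers `chunk_sizes` and `split_sizes` are hypothetical stand-ins that follow the documented size rules of `torch.chunk` and `torch.split`; they are not part of PyTorch. The sketch assumes both calls act on a dimension of length 8192. Both `.chunk` and `.split` default to `dim=0`, so which axis `.chunk(16)` actually cuts depends on the tensor layout in the PR, which is not shown here.

```python
import math

def chunk_sizes(length: int, chunks: int) -> list[int]:
    """Piece sizes for a fixed number of chunks (torch.chunk-style rule)."""
    size = math.ceil(length / chunks)
    full, rest = divmod(length, size)
    return [size] * full + ([rest] if rest else [])

def split_sizes(length: int, split_size: int) -> list[int]:
    """Piece sizes for a fixed chunk length (torch.split-style rule)."""
    full, rest = divmod(length, split_size)
    return [split_size] * full + ([rest] if rest else [])

seq_len = 8192
by_chunk = chunk_sizes(seq_len, 16)             # fixed number of pieces
by_split = split_sizes(seq_len, seq_len // 32)  # fixed piece length

print(len(by_chunk), by_chunk[0])
print(len(by_split), by_split[0])
```

Under these assumptions `.chunk(16)` gives 16 pieces of 512 tokens, while `.split(8192//32, dim=1)` gives 32 pieces of 256 tokens, so each piece is half the size.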

felipemello1 avatar Aug 19 '24 18:08 felipemello1

FYI, compiling loss + model together should yield much better results than compiling the model alone, if that's what is happening here.

Instead of doing:

torch.compile(model)

do something like:

@torch.compile()
def loss_step(model, input, label):
    # Forward pass, loss, and backward pass are compiled together.
    output = model(input)
    loss = calculate_loss(output, label)
    loss.backward()
    return loss

What we found is that using torch.compile on the cross entropy loss alone has great memory benefits (but not better than chunked): https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/

And compiling model + loss together almost doubles toks/second: https://github.com/pytorch/torchtune/issues/1228#issuecomment-2277151232

But the best results for us come from compiling only the model and using the chunked cross entropy. If we compile everything, the benefits of the chunked cross entropy are lost.
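
For readers unfamiliar with the idea, here is a minimal plain-Python sketch of chunked cross entropy. It is not the implementation from this PR or from torchtune; the function names and toy data are made up for illustration. It only shows that summing the loss chunk by chunk and dividing by the total token count gives the same mean as processing all tokens at once. As I understand it, a tensor version applies the same pattern to slices of the logits, so that only one chunk's worth of upcast logits exists at a time.

```python
import math

def cross_entropy_sum(logits, targets):
    """Summed cross entropy; logits is a list of per-token score rows."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for a stable log-sum-exp
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
    return total

def chunked_cross_entropy(logits, targets, chunk_size):
    """Mean cross entropy, processing chunk_size tokens at a time."""
    total = 0.0
    for start in range(0, len(logits), chunk_size):
        end = start + chunk_size
        total += cross_entropy_sum(logits[start:end], targets[start:end])
    return total / len(logits)

# Toy data: 4 tokens, vocabulary of 3.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3], [1.5, -0.5, 0.0], [0.0, 0.0, 3.0]]
targets = [0, 2, 1, 2]

full = cross_entropy_sum(logits, targets) / len(logits)
chunked = chunked_cross_entropy(logits, targets, chunk_size=2)
print(full, chunked)
```

The two printed values agree up to floating-point rounding, including when the chunk size does not divide the token count evenly.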

felipemello1 avatar Aug 19 '24 21:08 felipemello1

If I try to compile both the output linear and cross entropy loss together instead of just compiling the cross entropy loss, I get OOMs at the same batch size.

awgu avatar Aug 21 '24 15:08 awgu

My uneducated guess is that the optimizations they made for CrossEntropyLoss only account for the loss being compiled on its own. Details of their implementation are here: https://fb.workplace.com/groups/257735836456307/permalink/708422718054281/

felipemello1 avatar Aug 21 '24 15:08 felipemello1

Llama3-8B

With these changes:

[rank0]:2024-08-21 08:44:32,865 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:44:32,897 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:44:32,953 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:44:45,742 - root - INFO - GPU memory usage for model: 3.99GiB(4.20%)
[rank0]:2024-08-21 08:44:45,743 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-21 08:46:07,756 - root - INFO - step:  1  loss: 12.2044  memory: 71.90GiB(75.65%)  wps: 100  mfu: 0.58%
[rank0]:2024-08-21 08:46:07,756 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:46:19,266 - root - INFO - step: 10  loss: 10.8650  memory: 82.42GiB(86.72%)  wps: 6,405  mfu: 37.51%
[rank0]:2024-08-21 08:46:32,000 - root - INFO - step: 20  loss:  9.1536  memory: 82.42GiB(86.72%)  wps: 6,434  mfu: 37.68%
[rank0]:2024-08-21 08:46:44,771 - root - INFO - step: 30  loss:  8.1057  memory: 82.42GiB(86.72%)  wps: 6,416  mfu: 37.57%

Baseline:

[rank0]:2024-08-21 08:47:49,505 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:47:49,829 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:47:49,892 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:48:01,630 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-21 08:48:01,631 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-21 08:48:56,359 - root - INFO - step:  1  loss: 12.2556  memory: 67.73GiB(71.26%)  wps: 150  mfu: 0.88%
[rank0]:2024-08-21 08:48:56,359 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:49:08,017 - root - INFO - step: 10  loss: 10.8891  memory: 71.64GiB(75.38%)  wps: 6,324  mfu: 37.03%
[rank0]:2024-08-21 08:49:20,989 - root - INFO - step: 20  loss:  9.0440  memory: 71.64GiB(75.38%)  wps: 6,316  mfu: 36.99%
[rank0]:2024-08-21 08:49:33,984 - root - INFO - step: 30  loss:  8.0371  memory: 71.64GiB(75.38%)  wps: 6,305  mfu: 36.92%
[rank0]:2024-08-21 08:49:46,985 - root - INFO - step: 40  loss:  7.4550  memory: 71.64GiB(75.38%)  wps: 6,302  mfu: 36.90%
[rank0]:2024-08-21 08:50:00,009 - root - INFO - step: 50  loss:  7.2264  memory: 71.64GiB(75.38%)  wps: 6,290  mfu: 36.84%

Moving .float() into the CE loss and compiling it (P1539164756):

[rank0]:2024-08-21 08:55:16,471 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-21 08:55:16,786 - root - INFO - Compiling each TransformerBlock with torch.compile
[rank0]:2024-08-21 08:55:16,847 - root - INFO - Applied FSDP to the model
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:2024-08-21 08:55:29,108 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-21 08:55:29,110 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 8, sequence length 8192, total steps 1000 (warmup 200)
[rank0]:/data/users/andgu/pytorch/torch/_inductor/lowering.py:1673: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]:  warnings.warn(
[rank0]:2024-08-21 08:56:21,884 - root - INFO - step:  1  loss: 12.2164  memory: 58.18GiB(61.22%)  wps: 155  mfu: 0.91%
[rank0]:2024-08-21 08:56:21,884 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:2024-08-21 08:56:33,351 - root - INFO - step: 10  loss: 10.8179  memory: 66.01GiB(69.46%)  wps: 6,430  mfu: 37.65%
[rank0]:2024-08-21 08:56:46,101 - root - INFO - step: 20  loss:  9.0846  memory: 66.01GiB(69.46%)  wps: 6,426  mfu: 37.63%
[rank0]:2024-08-21 08:56:58,879 - root - INFO - step: 30  loss:  8.0600  memory: 66.01GiB(69.46%)  wps: 6,412  mfu: 37.55%
[rank0]:2024-08-21 08:57:11,658 - root - INFO - step: 40  loss:  7.4393  memory: 66.01GiB(69.46%)  wps: 6,411  mfu: 37.54%
[rank0]:2024-08-21 08:57:24,460 - root - INFO - step: 50  loss:  7.1899  memory: 66.01GiB(69.46%)  wps: 6,400  mfu: 37.48%
[rank0]:2024-08-21 08:57:37,291 - root - INFO - step: 60  loss:  7.0205  memory: 66.01GiB(69.46%)  wps: 6,386  mfu: 37.39%
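
For a sense of scale, the size of the logits tensor can be estimated with a short calculation. The assumptions here are mine and are not stated in the logs: a Llama 3 vocabulary of 128,256 tokens, bf16 logits coming out of the model (2 bytes per element), and an fp32 copy produced by `.float()` (4 bytes per element).

```python
# Assumptions (not stated in the logs): vocabulary of 128,256 tokens,
# bf16 logits from the model, fp32 copy created by .float().
batch_size, seq_len, vocab_size = 1, 8192, 128_256
GiB = 1024 ** 3

elements = batch_size * seq_len * vocab_size
bf16_gib = elements * 2 / GiB
fp32_gib = elements * 4 / GiB
print(f"bf16 logits: {bf16_gib:.2f} GiB, fp32 copy: {fp32_gib:.2f} GiB")
```

Under these assumptions, an eager `.float()` on the full logits materializes roughly 3.9 GiB per GPU on top of the bf16 logits. That is the same order of magnitude as the gap between 71.64 GiB (baseline) and 66.01 GiB (compiled CE loss with `.float()` inside), although the logs alone do not show how much of that gap comes from this tensor.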

awgu avatar Aug 21 '24 15:08 awgu