
Zero Redundancy Optimizer - Stage1

Open · chinthysl opened this issue

To train much larger model variants (2B, 7B, etc.), we need larger GPU memory allocations for parameters, optimizer states, and gradients. The Zero Redundancy Optimizer (ZeRO) introduces a methodology for sharding these tensors across processes to save GPU memory in multi-GPU training. This PR introduces the changes for a ZeRO Stage 1 implementation.

I tested using 1, 2, and 8 processes on a DGX-A100. AdamW optimizer state memory decreases linearly (2x and 8x smaller), and tok/s increases almost linearly (2x and 8x), with an added latency of around 10-20 ms from the NCCL all-gather operations.
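
For intuition, a minimal sketch of what the Stage 1 update step amounts to: each rank runs AdamW only on its own 1/num_processes slice of the parameters (so m and v are allocated at shard size), and an NCCL all-gather then reassembles the full updated parameter tensor on every rank. The sketch assumes fp32 parameters for simplicity and exact divisibility of the parameter count; adamw_update is a hypothetical kernel launcher, not the actual llm.c function.

#include <cuda_runtime.h>
#include <nccl.h>

// Sketch only: adamw_update is a placeholder for a fused AdamW kernel launch,
// not an actual llm.c function.
void adamw_update(float* p, float* g, float* m, float* v, size_t n, cudaStream_t s);

void zero1_step(float* params,        // full parameter buffer, replicated on every rank
                float* grads,         // full gradients, already all-reduced across ranks
                float* m, float* v,   // optimizer state, sharded: shard_size elements each
                size_t num_parameters, int rank, int num_processes,
                ncclComm_t comm, cudaStream_t stream) {
    size_t shard_size = num_parameters / num_processes;   // assumes exact divisibility
    size_t shard_offset = (size_t)rank * shard_size;

    // 1) each rank updates only its own slice of the parameters with AdamW,
    //    so m/v only ever exist at 1/num_processes of the full size
    adamw_update(params + shard_offset, grads + shard_offset, m, v, shard_size, stream);

    // 2) in-place all-gather: every rank sends its updated slice and receives
    //    everyone else's, so the full parameter buffer is identical again
    ncclAllGather(params + shard_offset, params, shard_size, ncclFloat, comm, stream);
}

The per-step all-gather over the full parameter tensor is where the extra 10-20 ms of latency mentioned above comes from.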

mpirun -np 1 train_gpt2cu -i data/TinyStories -z 1
+-----------------------+----------------------------------------------------+
| Zero Stage1 is enabled                                                     |
| num_processes         | 1                                                  |
| zero_stage            | 1                                                  |
+-----------------------+----------------------------------------------------+
num_parameters: 124475904 ==> bytes: 248951808
allocated 237 MiB for model parameters
allocated 2853 MiB for activations
val loss 2.373094
allocated 237 MiB for parameter gradients
allocated 126 MiB for activation gradients
allocated 474 MiB for AdamW optimizer state m
allocated 474 MiB for AdamW optimizer state v
step    1/225989: train loss 2.378581 (acc 2.378581) (89.811098 ms, 45606 tok/s)
step    2/225989: train loss 3.015817 (acc 3.015817) (86.563665 ms, 47317 tok/s)
step    3/225989: train loss 2.499960 (acc 2.499960) (87.369672 ms, 46881 tok/s)
step    4/225989: train loss 2.261637 (acc 2.261637) (70.757411 ms, 57887 tok/s)
step    5/225989: train loss 2.405512 (acc 2.405512) (73.156067 ms, 55989 tok/s)
step    6/225989: train loss 2.136998 (acc 2.136998) (73.465045 ms, 55754 tok/s)
step    7/225989: train loss 2.171301 (acc 2.171301) (73.596110 ms, 55655 tok/s)
step    8/225989: train loss 2.074431 (acc 2.074431) (73.237470 ms, 55927 tok/s)
step    9/225989: train loss 2.073056 (acc 2.073056) (73.568559 ms, 55675 tok/s)
step   10/225989: train loss 2.066257 (acc 2.066257) (73.369085 ms, 55827 tok/s)
step   11/225989: train loss 2.010005 (acc 2.010005) (84.799515 ms, 48302 tok/s)
step   12/225989: train loss 2.003811 (acc 2.003811) (87.137056 ms, 47006 tok/s)
step   13/225989: train loss 2.020144 (acc 2.020144) (87.132569 ms, 47008 tok/s)
step   14/225989: train loss 2.029459 (acc 2.029459) (87.166792 ms, 46990 tok/s)
step   15/225989: train loss 1.975009 (acc 1.975009) (87.325419 ms, 46905 tok/s)
step   16/225989: train loss 2.031827 (acc 2.031827) (86.789867 ms, 47194 tok/s)
step   17/225989: train loss 1.907860 (acc 1.907860) (87.003316 ms, 47078 tok/s)
step   18/225989: train loss 1.971371 (acc 1.971371) (87.260207 ms, 46940 tok/s)
step   19/225989: train loss 1.925187 (acc 1.925187) (77.354431 ms, 52951 tok/s)
step   20/225989: train loss 1.930509 (acc 1.930509) (71.431011 ms, 57342 tok/s)
val loss 1.924512

mpirun -np 2 train_gpt2cu -i data/TinyStories -z 1
+-----------------------+----------------------------------------------------+
| Zero Stage1 is enabled                                                     |
| num_processes         | 2                                                  |
| zero_stage            | 1                                                  |
+-----------------------+----------------------------------------------------+
num_parameters: 124475904 ==> bytes: 248951808
allocated 237 MiB for model parameters
allocated 2853 MiB for activations
val loss 2.374750
allocated 237 MiB for parameter gradients
allocated 126 MiB for activation gradients
allocated 237 MiB for AdamW optimizer state m
allocated 237 MiB for AdamW optimizer state v
step    1/112994: train loss 2.378581 (acc 2.379556) (96.999654 ms, 84453 tok/s)
step    2/112994: train loss 3.186077 (acc 3.160388) (104.673116 ms, 78262 tok/s)
step    3/112994: train loss 2.607240 (acc 2.525164) (108.197008 ms, 75713 tok/s)
step    4/112994: train loss 2.319467 (acc 2.267938) (101.117657 ms, 81014 tok/s)
step    5/112994: train loss 2.208531 (acc 2.210081) (96.576283 ms, 84824 tok/s)
step    6/112994: train loss 2.145689 (acc 2.153934) (108.735815 ms, 75338 tok/s)
step    7/112994: train loss 2.149342 (acc 2.152325) (91.689472 ms, 89345 tok/s)
step    8/112994: train loss 2.101048 (acc 2.126100) (95.046731 ms, 86189 tok/s)
step    9/112994: train loss 2.018039 (acc 2.066898) (86.529522 ms, 94672 tok/s)
step   10/112994: train loss 2.020076 (acc 2.028252) (85.158306 ms, 96197 tok/s)
step   11/112994: train loss 2.110162 (acc 2.072971) (80.118961 ms, 102247 tok/s)
step   12/112994: train loss 2.041743 (acc 2.021564) (81.944045 ms, 99970 tok/s)
step   13/112994: train loss 2.076852 (acc 2.057523) (84.026460 ms, 97493 tok/s)
step   14/112994: train loss 2.002669 (acc 2.026710) (84.166291 ms, 97331 tok/s)
step   15/112994: train loss 1.962275 (acc 1.973516) (101.858121 ms, 80425 tok/s)
step   16/112994: train loss 1.914360 (acc 1.936451) (108.651898 ms, 75396 tok/s)
step   17/112994: train loss 1.908354 (acc 1.893307) (103.965212 ms, 78795 tok/s)
step   18/112994: train loss 1.938197 (acc 1.941555) (95.846327 ms, 85470 tok/s)
step   19/112994: train loss 1.809096 (acc 1.792823) (91.729648 ms, 89305 tok/s)
step   20/112994: train loss 1.947107 (acc 1.904431) (91.292329 ms, 89733 tok/s)
val loss 1.885905

mpirun -np 8 train_gpt2cu -i data/TinyStories -z 1
+-----------------------+----------------------------------------------------+
| Zero Stage1 is enabled                                                     |
| num_processes         | 8                                                  |
| zero_stage            | 1                                                  |
+-----------------------+----------------------------------------------------+
num_parameters: 124475904 ==> bytes: 248951808
allocated 237 MiB for model parameters
allocated 2853 MiB for activations
val loss 2.385154
allocated 237 MiB for parameter gradients
allocated 126 MiB for activation gradients
allocated 59 MiB for AdamW optimizer state m
allocated 59 MiB for AdamW optimizer state v
step    1/28248: train loss 2.378581 (acc 2.384222) (115.159932 ms, 284543 tok/s)
step    2/28248: train loss 3.294697 (acc 3.265185) (111.280736 ms, 294462 tok/s)
step    3/28248: train loss 2.389775 (acc 2.421525) (91.648827 ms, 357538 tok/s)
step    4/28248: train loss 2.273024 (acc 2.249920) (99.208216 ms, 330295 tok/s)
step    5/28248: train loss 2.176223 (acc 2.169333) (102.035945 ms, 321141 tok/s)
step    6/28248: train loss 2.188320 (acc 2.159471) (135.788975 ms, 241315 tok/s)
step    7/28248: train loss 2.046117 (acc 2.089896) (131.576817 ms, 249040 tok/s)
step    8/28248: train loss 2.023791 (acc 2.103906) (131.244736 ms, 249670 tok/s)
step    9/28248: train loss 1.938758 (acc 2.045031) (131.523668 ms, 249141 tok/s)
step   10/28248: train loss 1.920710 (acc 1.996141) (122.257226 ms, 268025 tok/s)
step   11/28248: train loss 1.987167 (acc 2.006098) (134.310771 ms, 243971 tok/s)
step   12/28248: train loss 2.015834 (acc 1.933221) (107.914139 ms, 303648 tok/s)
step   13/28248: train loss 1.963828 (acc 1.941307) (98.656394 ms, 332142 tok/s)
step   14/28248: train loss 1.917733 (acc 1.932031) (96.455467 ms, 339721 tok/s)
step   15/28248: train loss 1.994990 (acc 1.883007) (101.058786 ms, 324246 tok/s)
step   16/28248: train loss 1.859246 (acc 1.889248) (114.783167 ms, 285477 tok/s)
step   17/28248: train loss 1.922480 (acc 1.859198) (118.057020 ms, 277560 tok/s)
step   18/28248: train loss 1.726868 (acc 1.858811) (127.837003 ms, 256326 tok/s)
step   19/28248: train loss 1.798945 (acc 1.847258) (131.560066 ms, 249072 tok/s)
step   20/28248: train loss 1.794839 (acc 1.844039) (119.471586 ms, 274274 tok/s)
val loss 1.826622

chinthysl · Apr 30 '24 09:04

Great, excited to get here! The optimization is still in a bit of flux, especially around 1) gradient accumulation and 2) gradient clipping. I want to get those in first before we reach for sharding, since sharding will probably be impacted by them, in particular gradient clipping, which is a reduction over all gradients.
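
For context on the interaction with sharding: clipping by global norm requires the total gradient norm across every process before any rank can scale its gradients, which adds one small all-reduce per step. A rough sketch of that pattern, where grad_sq_norm_partial and scale_gradients are hypothetical kernels (not llm.c functions) and d_sq_norm is a one-element device scratch buffer:

#include <math.h>
#include <cuda_runtime.h>
#include <nccl.h>

// Sketch only: placeholders for the reduction/scale kernels, not actual llm.c functions.
void grad_sq_norm_partial(const float* grads, size_t n, float* d_out, cudaStream_t s);
void scale_gradients(float* grads, size_t n, float scale, cudaStream_t s);

float clip_global_grad_norm(float* grads, size_t local_count, float max_norm,
                            float* d_sq_norm,  // 1-element device scratch buffer
                            ncclComm_t comm, cudaStream_t stream) {
    // partial sum of squared gradient elements owned by this rank
    grad_sq_norm_partial(grads, local_count, d_sq_norm, stream);
    // combine the partial sums across all processes; every rank gets the same total
    ncclAllReduce(d_sq_norm, d_sq_norm, 1, ncclFloat, ncclSum, comm, stream);
    float global_sq;
    cudaMemcpyAsync(&global_sq, d_sq_norm, sizeof(float), cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    float global_norm = sqrtf(global_sq);
    // every rank applies the same scale factor, keeping the replicas consistent
    if (global_norm > max_norm) {
        scale_gradients(grads, local_count, max_norm / global_norm, stream);
    }
    return global_norm;
}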

karpathy · Apr 30 '24 15:04

Hi @chinthysl, heads up that I just merged a PR to (optionally) keep master weights in fp32. I think this impacts this PR:

https://github.com/karpathy/llm.c/pull/328

eager to merge this one though!

karpathy · May 01 '24 23:05

@karpathy Thank you for reviewing. I'm eager to have a look at 1) gradient accumulation and 2) gradient clipping to see if I can contribute. I refactored the PR to accommodate the #328 changes and the review suggestions from @PeterZhizhin.

chinthysl · May 02 '24 02:05

@karpathy I've fixed the previous CI issues. Waiting until the multi-GPU hanging issue gets resolved.

chinthysl · May 07 '24 03:05

@karpathy I added changes to shard the master weights and removed the now-unnecessary all-gather of the master weights. ty.
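
Roughly, the reason the separate master-weight all-gather can be dropped: with bf16 parameters and an fp32 master copy, each rank keeps only its shard of the master copy (together with its m/v shards), performs the AdamW update in fp32, casts its shard back to bf16, and only the bf16 parameters are all-gathered; the master weights are never reassembled. A sketch under those assumptions, with adamw_update_master and cast_fp32_to_bf16 as hypothetical kernels (not llm.c functions) and assuming an NCCL build with ncclBfloat16 support:

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <nccl.h>

// Sketch only: placeholders for the fp32-state AdamW kernel and the downcast kernel.
void adamw_update_master(float* master, const __nv_bfloat16* grads,
                         float* m, float* v, size_t n, cudaStream_t s);
void cast_fp32_to_bf16(__nv_bfloat16* dst, const float* src, size_t n, cudaStream_t s);

void zero1_step_with_master(__nv_bfloat16* params,  // full bf16 parameter buffer
                            __nv_bfloat16* grads,   // full bf16 gradients, already reduced
                            float* master, float* m, float* v,  // fp32 shards only
                            size_t num_parameters, int rank, int num_processes,
                            ncclComm_t comm, cudaStream_t stream) {
    size_t shard_size = num_parameters / num_processes;  // assumes exact divisibility
    size_t shard_offset = (size_t)rank * shard_size;

    // 1) AdamW in fp32 on this rank's master shard, reading the bf16 gradient shard
    adamw_update_master(master, grads + shard_offset, m, v, shard_size, stream);

    // 2) write the updated shard back into the bf16 parameter buffer
    cast_fp32_to_bf16(params + shard_offset, master, shard_size, stream);

    // 3) all-gather only the bf16 parameters; master/m/v stay sharded throughout
    ncclAllGather(params + shard_offset, params, shard_size, ncclBfloat16, comm, stream);
}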

chinthysl · May 13 '24 06:05

Some notes on this PR from exploration on my GPU box

My current default "go to" run is this 124M model configuration:

make train_gpt2cu USE_CUDNN=1
mpirun -np 4 ./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -o stories.log -b 32

When I run this with ZeRO additionally turned off (-z 0) or on (-z 1), I see:

-z 0: 287ms/step, 456K tok/s, 16661 MiB/GPU
-z 1: 327ms/step, 400K tok/s, 15593 MiB/GPU

So we accumulate latency here but don't get big memory gains, because most of the memory at this model size is activations. E.g. with -z 1 this looks like:

allocated 237 MiB for model parameters
allocated 13629 MiB for activations
allocated 237 MiB for parameter gradients
allocated 240 MiB for activation gradients
allocated 118 MiB for AdamW optimizer state m
allocated 118 MiB for AdamW optimizer state v
allocated 118 MiB for master copy of params

i.e. the majority of memory by far (13.6GB) is activations.

But now trying with the actual GPT-2 (1.6B) model:

# actual GPT-2 model:
mpirun -np 4 ./train_gpt2cu -z 1 -e gpt2_1558M_bf16.bin -i data/TinyStories -v 250 -s 250 -g 144 -o stories.log -b 4

# note this is batch size only 4, to make things fit
# -z 0: 1520ms/step, 10,745 tok/s, 35975 MiB/GPU
# -z 1: 2110ms/step, 7,757 tok/s, 22601 MiB/GPU
# => 37% memory reduction

So the latency goes up quite a bit, but we're now saving a good amount of memory. These memory savings now allow us to crank up the batch size, and get a net token throughput win, e.g. increasing batch size to 8 or even 10, which almost exactly maxes out VRAM on my A100 40GB:

# -z 1 -b 8: 2347ms/step, 13,930 tok/s, 33801 MiB/GPU
# -z 1 -b 10: 16,537 tok/s, 39475 MiB/GPU

So the original batch size of 4 that fits can go up to 10, and we go from 10,745 tok/s (-b 4) up to 16,537 tok/s (-b 10), i.e. roughly a 54% gain in token training throughput (16,537 / 10,745 ≈ 1.54).

pretty cool!!

All timings are on my 4x A100 40GB GPU node (PCIe interconnect only! Of note, this is a very slow and non-standard interconnect, and more modern boxes should go a lot faster here).

karpathy · May 13 '24 09:05

Ok finally had time to step through in detail, LGTM ty.

karpathy · May 13 '24 20:05