
[benchmark] migrate to fsdp/ddp after jit, from fsdp/ddp before jit

Open crcrpar opened this issue 1 year ago • 4 comments

Llama-2-7b-hf & fsdp

Container of 20240804, 8× H100 80GB HBM3. Command: `torchrun --nproc_per_node=8 thunder/benchmarks/benchmark_litgpt.py --compile=thunder_inductor_cat --model_name=Llama-2-7b-hf --distributed_mode=fsdp --shard_mode=zero2 --bucketing_mode=none`

|              | main 5cc3011 | PR 3869547 |
| ------------ | ------------ | ---------- |
| tokens/s     | 85765.16     | 85695.95   |
| tokens/s/GPU | 10720.65     | 10711.99   |
| Memory       | 42.06 GB     | 42.06 GB   |

cc @mpatel31415

related:

  • https://github.com/Lightning-AI/lightning-thunder/issues/478

crcrpar · Aug 05 '24 06:08

There may already be a plan for this since the PR is still a draft, and apologies if I am jumping ahead, but given the big increase in memory consumption with this change, and that many workloads in @mpatel31415's benchmarks are already failing with OOM errors, do you think we could make this change optional?

parthmannan · Aug 05 '24 06:08

I took memory snapshots of both runs to see whether the memory increase comes from the training step.
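
(For context, a minimal sketch of how such snapshots can be captured with PyTorch's allocator-history API; the exact capture procedure used here is an assumption, not taken from the benchmark script.)

```python
import torch

# Start recording CUDA caching-allocator events; max_entries bounds the history buffer.
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run warmup plus a few training iterations here ...

# Dump the recorded history to a pickle that can be inspected at
# https://pytorch.org/memory_viz, then stop recording.
torch.cuda.memory._dump_snapshot("memory_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```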

main: [memory snapshot image]

PR: [memory snapshot image]

It seems that the difference comes from outside the training loop, i.e., from the sharding, materialization, and so on of fsdp(jit(model)) versus jit(fsdp(model)).
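
(For reference, a rough sketch of the two wrapping orders being compared, assuming the thunder.jit and thunder.distributed.fsdp entry points; this is not the actual benchmark code, and a real run needs torch.distributed initialized via torchrun with the module on the local GPU.)

```python
import torch.nn as nn
import thunder
from thunder.distributed import fsdp

# Toy stand-in for the litgpt model used in the benchmark.
model = nn.Linear(4096, 4096)

# main: shard the module first, then jit the already-sharded module.
cmodel_main = thunder.jit(fsdp(model))

# this PR: jit the module first, then shard the jitted module.
cmodel_pr = fsdp(thunder.jit(model))
```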

crcrpar · Aug 05 '24 12:08

Just so I understand the snapshot above: the blue markers are memory allocations during the training step, right? Do we know why fsdp(jit(model)) has higher consumption? Is it because, once the model is jitted, not all of the parameters can be sharded since some get hidden inside fusion blocks? Or perhaps fsdp after jit has to AllGather larger buckets because of the fusion blocks?

parthmannan · Aug 07 '24 00:08

> Just so I understand the snapshot above: the blue markers are memory allocations during the training step, right? Do we know why fsdp(jit(model)) has higher consumption? Is it because, once the model is jitted, not all of the parameters can be sharded since some get hidden inside fusion blocks? Or perhaps fsdp after jit has to AllGather larger buckets because of the fusion blocks?

The stacked allocations spanning horizontally in the bottom capture seem to indicate that materialization and/or sharding under fsdp(jit(model)) leaves some parameters dangling somewhere even though they are not used.

I do think it's because fsdp(jit(model)) copies the parameters after materializing them but before sharding, whereas jit(fsdp(model)) applies sharding to the parameters directly. #932 would be a workaround when the model is initialized with torch.device("meta").
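
(As an illustration of the meta-device initialization mentioned above; the module is a hypothetical stand-in, and this is not the change proposed in #932.)

```python
import torch
import torch.nn as nn

# Initializing under torch.device("meta") allocates no real storage, so sharding
# can be applied to the parameters before a full, unsharded copy is ever
# materialized on the GPU.
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096))

assert all(p.device.type == "meta" for p in model.parameters())
```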

crcrpar · Aug 07 '24 03:08