lightning-thunder
[benchmark] migrate to fsdp/ddp after jit, from fsdp/ddp before jit
Llama-2-7b-hf & fsdp
Container from 2024-08-04, 8× H100 80GB HBM3.
Command:

```shell
torchrun --nproc_per_node=8 thunder/benchmarks/benchmark_litgpt.py \
  --compile=thunder_inductor_cat --model_name=Llama-2-7b-hf \
  --distributed_mode=fsdp --shard_mode=zero2 --bucketing_mode=none
```
| | main 5cc3011 | pr 3869547 |
|---|---|---|
| tokens/s | 85765.16 | 85695.95 |
| tokens/s/GPU | 10720.65 | 10711.99 |
| Memory | 42.06 GB | 42.06 GB |
cc @mpatel31415
related:
- https://github.com/Lightning-AI/lightning-thunder/issues/478
There may already be a plan for this since the PR is still a draft, and apologies if I am jumping ahead, but given the large increase in memory consumption with this change, and that many workloads from @mpatel31415's benchmarks already fail with OOM errors, do you think we can make this change optional?
I took memory snapshots of both runs to see whether the memory increase comes from the training step.
main: (memory snapshot)

pr: (memory snapshot)
It seems that the difference comes from outside of the training loop, i.e. from the sharding, materialization, and so on of `fsdp(jit(model))` versus `jit(fsdp(model))`.
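For reference, here is a minimal sketch of how such allocator snapshots can be captured with PyTorch's built-in memory-history recorder; the placement of the calls around the benchmark iterations is an assumption, not the exact procedure used for the captures above.

```python
import torch

# Start recording allocation/free events with stack traces.
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run setup plus a few warmup and training iterations of
# thunder/benchmarks/benchmark_litgpt.py here ...

# Dump the recorded history to a pickle file and stop recording.
torch.cuda.memory._dump_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```

The resulting `snapshot.pickle` can be inspected interactively at https://pytorch.org/memory_viz.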
Just so I understand the snapshots above, the blue markers are memory allocations during the training step, right? Do we know why `fsdp(jit(model))` has higher consumption? Is it because once the model is jitted, it cannot shard all the parameters, as some get hidden inside fusion blocks? Or perhaps fsdp after jit has to AllGather larger buckets due to fusion blocks?
The piled layers spanning horizontally in the bottom capture seem to indicate that materialization and/or sharding of `fsdp(jit(model))` leaves some parameters dangling somewhere even though they are not used.
I do think it's because `fsdp(jit(model))` copies the params after materialization but before sharding, while `jit(fsdp(model))` applies sharding to the params directly.
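To make the comparison concrete, here is a minimal sketch of the two composition orders; the toy module is a placeholder for the litgpt Llama-2-7b-hf model, and the snippet assumes a `torch.distributed` process group has been initialized (e.g. under torchrun).

```python
import torch.nn as nn
import thunder
from thunder.distributed import fsdp

# Tiny stand-in model; the real benchmark builds Llama-2-7b-hf via litgpt.
model = nn.Sequential(nn.Linear(256, 256), nn.GELU(), nn.Linear(256, 256)).to("cuda")

# Option A (before this PR): shard first, then trace the already-sharded module.
tm = thunder.jit(fsdp(model))

# Option B (with this PR): trace first, then shard the jitted module.
# Per the discussion above, this path copies params after materialization
# but before sharding, which is the suspected source of the extra peak memory.
# tm = fsdp(thunder.jit(model))
```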
#932 would be a workaround when a model is initialized with `torch.device("meta")`.
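For context, this is roughly what meta-device initialization looks like in plain PyTorch; the module and init scheme below are placeholders, and how #932 hooks this into thunder's materialization is not shown here.

```python
import torch
import torch.nn as nn

# Construct the module on the meta device: parameter shapes and dtypes exist,
# but no real memory is allocated.
with torch.device("meta"):
    model = nn.Linear(4096, 4096)

# Later, allocate real (uninitialized) storage on the GPU and initialize in
# place, so the full set of unsharded, initialized weights never has to be
# copied around before sharding.
model = model.to_empty(device="cuda")
for p in model.parameters():
    torch.nn.init.normal_(p)  # illustrative init only
```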