lightning-thunder
Implement `no_sync` for `thunder.distributed.fsdp` (PR2457)
TL;DR

Enables `no_sync` for `thunder.jit(thunder.distributed.fsdp(model))`. The accompanying changes are:

- new argument `return_none_instead_of_grads` of `ThunderFunction.forward`
  - This could be eliminated once a `TraceCtx`'s bound symbols are not deleted even if it just returns one or more `None`s
- removal of the `no_sync` check before applying `dist_prims.synchronize` to args and kwargs
  - FSDP's forward needs this prim for its param AllGather
- [ddp] `visitor_transform` removes `dist_prims.all_reduce`, `dist_prims.wait`, and pre-averaging when `no_sync`
- [fsdp] `visitor_transform` removes comms and puts `dist_prims.stash_grad_for_fsdp` and an optional param AllGather when `no_sync`
  - The generated trace and its executable Python code return unsynchronized unsharded gradients.
  - The prim's implementation accumulates the grads as `param._thunder_fsdp_unsharded_grad`. `ThunderFunction`'s `backward` returns `None`s instead of such grads to avoid a shape mismatch between params and unsharded grads.
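To make the `stash_grad_for_fsdp` semantics above concrete, here is a minimal, torch-free sketch. The `Param`, `SyncState`, and `no_sync` names are illustrative stand-ins, not thunder's actual classes; only the accumulation onto `param._thunder_fsdp_unsharded_grad` and the `None` return mirror what the PR describes.

```python
from contextlib import contextmanager


class Param:
    """Stand-in for a sharded FSDP parameter (hypothetical, for illustration)."""
    def __init__(self, shard):
        self.shard = shard  # this rank's shard of the full parameter


def stash_grad_for_fsdp(param, unsharded_grad):
    # Accumulate the *unsharded* gradient on the parameter, mimicking
    # the `param._thunder_fsdp_unsharded_grad` attribute from the PR.
    acc = getattr(param, "_thunder_fsdp_unsharded_grad", None)
    if acc is None:
        param._thunder_fsdp_unsharded_grad = list(unsharded_grad)
    else:
        for i, g in enumerate(unsharded_grad):
            acc[i] += g
    # Return None instead of the grad: the unsharded grad's shape
    # does not match the sharded parameter's shape.
    return None


class SyncState:
    sync = True


@contextmanager
def no_sync(state):
    # While inside, gradients are stashed locally instead of being
    # reduce-scattered; sync resumes on exit.
    state.sync = False
    try:
        yield
    finally:
        state.sync = True


# Usage: two accumulation microsteps under no_sync.
state = SyncState()
p = Param(shard=[0.0, 0.0])
with no_sync(state):
    for grad in ([1.0, 2.0, 3.0, 4.0], [0.5, 0.5, 0.5, 0.5]):
        assert stash_grad_for_fsdp(p, grad) is None
print(p._thunder_fsdp_unsharded_grad)  # [1.5, 2.5, 3.5, 4.5]
```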
Benchmark setup as of 93bfa4ae454326d6d3e858bbed21beed2d97809a:
- llama-2-7b-hf
- world size 8 H100s
- micro batch size 1
- global batch size 32
- gradient accumulation 4
- no bucketing (of AllGather and ReduceScatter)
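The gradient accumulation count of 4 follows from the batch configuration; a quick arithmetic check:

```python
# Sanity check: grad accumulation steps = global batch / (world size * micro batch).
world_size = 8
micro_batch = 1
global_batch = 32
grad_accum_steps = global_batch // (world_size * micro_batch)
print(grad_accum_steps)  # 4, matching "gradient accumulation 4"
```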
zero2
|  | w/ no_sync | w/o no_sync |
|---|---|---|
| tokens/sec | 81902.8 | 79553.6 |
| memory consumption [GB] | 65.6 | 40.4 |
zero3
Need to check the perf degradation...
|  | w/ no_sync | w/o no_sync |
|---|---|---|
| tokens/sec | 74267.9 | 75113.3 |
| memory consumption [GB] | 52.6 | 27.2 |
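The extra memory under `no_sync` (roughly 25 GB in both tables) is consistent with keeping one full set of unsharded gradients per rank. A rough estimate, assuming llama-2-7b's ~6.74B parameters and fp32 gradient buffers (both assumptions, not stated in the PR):

```python
# Rough estimate of the per-rank memory cost of stashing unsharded grads.
params = 6.74e9      # approx. llama-2-7b parameter count (assumption)
bytes_per_grad = 4   # fp32 gradient buffers (assumption)
extra_gb = params * bytes_per_grad / 1e9
print(round(extra_gb, 1))  # ~27 GB, same order as the ~25 GB deltas above
```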