
Implement `no_sync` for `thunder.distributed.fsdp` (PR2457)

Open · crcrpar opened this issue Mar 22 '24 · 0 comments

TL;DR

Enables `no_sync` for `thunder.jit(thunder.distributed.fsdp(model))` (see the usage sketch after the list below). The accompanying changes are:

  • a new argument, return_none_instead_of_grads, for ThunderFunction.forward
    • This could be removed once a TraceCtx's bound symbols are no longer deleted when they merely return one or more Nones.
  • removal of the no_sync check before applying dist_prims.synchronize to args and kwargs
    • FSDP's forward needs this prim for its parameter AllGather.
    • [ddp] visitor_transform removes dist_prims.all_reduce, dist_prims.wait, and the gradient pre-averaging when no_sync is active.
    • [fsdp] visitor_transform removes the gradient communication and inserts dist_prims.stash_grad_for_fsdp (plus an optional parameter AllGather) when no_sync is active.
      • The generated trace and its executable Python code return unsynchronized, unsharded gradients.
      • The prim's implementation accumulates the gradients into param._thunder_fsdp_unsharded_grad (a conceptual sketch follows the list).
      • ThunderFunction's backward returns Nones instead of those gradients, to avoid a shape mismatch between the (sharded) parameters and the unsharded gradients.
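
A minimal gradient-accumulation sketch of how this would be used. It assumes the jitted module exposes a `no_sync()` context manager analogous to PyTorch DDP's, supports the usual `nn.Module` methods, and that `torch.distributed` is already initialized (e.g. launched via `torchrun`):

```python
import contextlib

import torch
import thunder
import thunder.distributed

model = torch.nn.Linear(4096, 4096).cuda()
jitted = thunder.jit(thunder.distributed.fsdp(model))

grad_accum_steps = 4
optimizer = torch.optim.SGD(jitted.parameters(), lr=1e-3)

for step in range(8 * grad_accum_steps):
    x = torch.randn(1, 4096, device="cuda")
    sync = (step + 1) % grad_accum_steps == 0
    # On non-final micro-batches, skip the gradient communication; unsharded
    # gradients are stashed on param._thunder_fsdp_unsharded_grad instead.
    ctx = contextlib.nullcontext() if sync else jitted.no_sync()
    with ctx:
        jitted(x).sum().backward()
    if sync:
        optimizer.step()
        optimizer.zero_grad()
```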
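
And a conceptual sketch, not the actual prim implementation, of what dist_prims.stash_grad_for_fsdp does per the description above: accumulate the unsynchronized, unsharded gradient onto the parameter rather than reduce-scattering it.

```python
import torch

def stash_grad_for_fsdp(grad: torch.Tensor, param: torch.nn.Parameter) -> torch.Tensor:
    # Accumulate the unsharded gradient on the parameter instead of
    # communicating it; the stash is consumed on the next synced step.
    buf = getattr(param, "_thunder_fsdp_unsharded_grad", None)
    if buf is None:
        param._thunder_fsdp_unsharded_grad = grad.detach().clone()
    else:
        buf.add_(grad)
    return param._thunder_fsdp_unsharded_grad
```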

Benchmark results as of 93bfa4ae454326d6d3e858bbed21beed2d97809a:

  • llama-2-7b-hf
  • world size 8 H100s
  • micro batch size 1
  • global batch size 32
  • gradient accumulation 4
  • no bucketing (of AllGather and ReduceScatter)

zero2

|                         | w/ no_sync | w/o no_sync |
| ----------------------- | ---------- | ----------- |
| tokens/sec              | 81902.8    | 79553.6     |
| memory consumption [GB] | 65.6       | 40.4        |

zero3

Need to investigate the performance degradation with no_sync...

|                         | w/ no_sync | w/o no_sync |
| ----------------------- | ---------- | ----------- |
| tokens/sec              | 74267.9    | 75113.3     |
| memory consumption [GB] | 52.6       | 27.2        |
