
[FSDP] Support gradient clipping by norm

Open · carmocca opened this issue 1 year ago • 1 comment

🚀 Feature

Pitch

Port https://github.com/pytorch/pytorch/blob/c4a157086482899f0640d03292e5d2c9a6a3db68/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1069-L1194 to work with Thunder's FSDP.

This could be made importable as `from thunder.distributed.utils import clip_grad_norm_`. We could also move FSDP into `thunder.distributed.fsdp` and put this alongside it (`from thunder.distributed.fsdp import clip_grad_norm_`). Bikeshedding welcome.
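For illustration, here is a minimal sketch of what such a function could look like. This is not Thunder's implementation: the signature and the `process_group` argument are assumptions, and it presumes each rank holds a disjoint shard of every gradient, so the global norm can be recovered from local shard norms with a single all-reduce:

```python
import torch
import torch.distributed as dist

def clip_grad_norm_(parameters, max_norm, norm_type=2.0, process_group=None):
    # Sketch only: assumes each rank holds a disjoint shard of every gradient.
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    if norm_type == float("inf"):
        # Global max-norm: local max over shards, then max across ranks.
        total_norm = torch.stack([g.detach().abs().max() for g in grads]).max()
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=process_group)
    else:
        # Global p-norm: sum the p-th powers of the local shard norms across
        # ranks, then take the p-th root.
        local_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g.detach(), norm_type) for g in grads]),
            norm_type,
        )
        total_norm = local_norm ** norm_type
        dist.all_reduce(total_norm, group=process_group)
        total_norm = total_norm ** (1.0 / norm_type)
    # Scale every local shard by the same factor; a no-op when the global
    # norm is already within max_norm.
    clip_coef = torch.clamp(max_norm / (total_norm + 1e-6), max=1.0)
    for g in grads:
        g.mul_(clip_coef)
    return total_norm
```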

PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html, https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_

cc @carmocca @awaelchli @crcrpar

carmocca · Mar 05 '24 11:03

When we support and test compiling fwd-bwd-step together, we might want to reimplement this as a transform. But for the current pattern, where gradient clipping happens outside of the trace, we can simply write an ad-hoc function that the user calls.

carmocca · Mar 05 '24 16:03
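For context, the out-of-trace pattern the comment refers to would look roughly like the fragment below, using the hypothetical `clip_grad_norm_` sketched above; `model`, `batch`, and `optimizer` are assumed to already exist:

```python
# Hypothetical training step calling the ad-hoc function between
# backward and the optimizer step, outside the traced computation.
loss = model(batch).sum()
loss.backward()
clip_grad_norm_(model.parameters(), max_norm=1.0)  # all-reduces the global norm
optimizer.step()
optimizer.zero_grad()
```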