[FSDP] Support gradient clipping by norm
🚀 Feature
Pitch
Port https://github.com/pytorch/pytorch/blob/c4a157086482899f0640d03292e5d2c9a6a3db68/torch/distributed/fsdp/fully_sharded_data_parallel.py#L1069-L1194 to work with Thunder's FSDP.
This could be importable as `from thunder.distributed.utils import clip_grad_norm_`.
We could also move FSDP into `thunder.distributed.fsdp` and put this alongside it (`from thunder.distributed.fsdp import clip_grad_norm_`). Bikeshedding welcome.
PyTorch docs: https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html, https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
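At a high level, the sharded case needs to compute each rank's local norm contribution, reduce it across the process group, and then scale the local gradient shards by the same clip coefficient. A minimal sketch of that shape (not the ported upstream implementation; process group, dtype handling, and fused/foreach ops are left out):

```python
import torch
import torch.distributed as dist


def clip_grad_norm_(parameters, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
    """Sketch of clipping sharded gradients by their global norm (not the upstream FSDP code)."""
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.tensor(0.0)
    # Each rank only holds its shard of every gradient, so this is a *local* norm.
    local_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach(), ord=norm_type) for g in grads]),
        ord=norm_type,
    )
    if norm_type == float("inf"):
        # The global inf-norm is the max over all local maxima.
        total_norm = local_norm.clone()
        dist.all_reduce(total_norm, op=dist.ReduceOp.MAX)
    else:
        # For finite p, sum the p-th powers across ranks, then take the p-th root.
        total_norm = local_norm ** norm_type
        dist.all_reduce(total_norm)
        total_norm = total_norm ** (1.0 / norm_type)
    # Every rank scales its local shards by the same coefficient.
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm
```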
cc @carmocca @awaelchli @crcrpar
When we support and test compiling fwd-bwd-step together, we might want to reimplement this as a transform. But for the current pattern, where gradient clipping happens outside of the trace, we can simply provide an ad-hoc function that the user calls.
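For the ad-hoc path, usage could look roughly like the following; the import location is the one proposed above, and `model`, `optimizer`, and `batch` are placeholders for the user's training setup:

```python
# Hypothetical usage: the import path is the proposed one, not an existing API;
# model, optimizer, and batch come from the user's training loop.
from thunder.distributed.utils import clip_grad_norm_

loss = model(batch)
loss.backward()
# Clipping happens between backward and step, outside of the Thunder trace.
clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
optimizer.zero_grad()
```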