Support `clip_grad_norm_()` for FSDP
What does this PR do?
Adds gradient norm clipping support for FSDP. Tests pass locally.
For fun, here's a research deep dive ChatGPT put together comparing norm-based and value-based gradient clipping: https://chatgpt.com/s/dr_68168a3400988191be64b3c743a4ccf3.
Fixes #19235
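For illustration, here is a minimal sketch of how norm-based clipping is expected to be requested with the FSDP strategy through the existing Trainer arguments (assumed usage for context, not code from this PR):

```python
# Sketch only (assumed usage, not taken from this PR): norm-based gradient
# clipping requested through the standard Trainer arguments while training
# with the FSDP strategy.
from lightning.pytorch import Trainer

trainer = Trainer(
    strategy="fsdp",
    accelerator="gpu",
    devices=2,
    gradient_clip_val=1.0,           # maximum total gradient norm
    gradient_clip_algorithm="norm",  # the clipping mode this PR adds support for under FSDP
)
# trainer.fit(model) would then clip via FullyShardedDataParallel.clip_grad_norm_()
```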
Before submitting
- [ ] Was this discussed/agreed via a GitHub issue? (not for typos and docs)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- [ ] Did you make sure to update the documentation with your changes? (if necessary)
- [ ] Did you write any new necessary tests? (not for typos and docs)
- [x] Did you verify new and existing tests pass locally with your changes?
- [ ] Did you list all the breaking changes introduced by this pull request?
- [ ] Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)
PR review
Anyone in the community is welcome to review the PR. Before you start reviewing, make sure you have read the review guidelines. In short, see the following bullet-list:
Reviewer checklist
- [x] Is this pull request ready for review? (if not, please submit in draft mode)
- [x] Check that all items from Before submitting are resolved
- [x] Make sure the title is self-explanatory and the description concisely explains the PR
- [x] Add labels and milestones (and optionally projects) to the PR so it can be classified
📚 Documentation preview 📚: https://pytorch-lightning--20784.org.readthedocs.build/en/20784/
For some reason, Read the Docs is raising the following build error, which I don't believe my PR caused:
File "/tmp/pip-build-env-8dnu3z25/normal/lib/python3.9/site-packages/pbr/packaging.py", line 492, in run
bs_cmd, 'executable', easy_install.sys_executable)
AttributeError: module 'setuptools.command.easy_install' has no attribute 'sys_executable'
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for sphinxcontrib-fulltoc
ERROR: Failed to build installable wheels for some pyproject.toml based projects (sphinxcontrib-fulltoc)
Package Version
--------------- -----------
awscli 1.40.7
botocore 1.38.8
colorama 0.4.6
distlib 0.3.9
docutils 0.19
filelock 3.16.1
jmespath 1.0.1
pip 25.1.1
platformdirs 3.11.0
py-tree 1.0.1
pyasn1 0.6.1
python-dateutil 2.9.0.post0
PyYAML 6.0.2
rsa 4.7.2
s3transfer 0.12.0
setuptools 58.1.0
six 1.17.0
urllib3 1.26.20
virtualenv 20.21.1
[rtd-command-info] start-time: 2025-05-03T21:35:45.185287Z, end-time: 2025-05-03T21:35:45.237650Z, duration: 0, exit-code: 2
bash docs/rtfd-build.sh
+ '[' 20784 == latest -o 20784 == stable ']'
+ export FAST_DOCS_DEV=1
+ FAST_DOCS_DEV=1
++ pwd
+ root=/home/docs/checkouts/readthedocs.org/user_builds/pytorch-lightning/checkouts/20784
+ for pkg in 'fabric' 'pytorch'
+ cd /home/docs/checkouts/readthedocs.org/user_builds/pytorch-lightning/checkouts/20784/docs/source-fabric
++ nproc
+ make html --jobs 2
/bin/sh: 1: sphinx-build: not found
make: *** [Makefile:19: html] Error 127
I'm also not sure where the following "Code check / mypy (pull_request)" error is coming from:
src/lightning/pytorch/plugins/precision/fsdp.py:89: error: "Tensor" not callable [operator]
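I'd expect the same report from a minimal snippet like the one below (an assumption on my part; I haven't isolated it against the CI configuration):

```python
# Hypothetical minimal reproduction (not code from this PR): calling a
# dynamically resolved attribute on a plain nn.Module and running the file
# through mypy should produce the same '"Tensor" not callable' report.
from torch.nn import Module


def clip(module: Module) -> None:
    module.clip_grad_norm_(1.0)  # mypy: error: "Tensor" not callable  [operator]
```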
This pull request has been automatically marked as stale because it has not had recent activity. It will be closed in 7 days if no further activity occurs. If you need further help see our docs: https://lightning.ai/docs/pytorch/latest/generated/CONTRIBUTING.html#pull-request or ask the assistance of a core contributor here or on Discord. Thank you for your contributions.
Let's check the typing, which is now set for PT 2.8.
@Borda,
I have no idea why mypy is flagging the following (last) line of this function:
@override
def clip_grad_by_norm(self, module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
    # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
    if module is None:
        return
    module.clip_grad_norm_(clip_val)
mypy thinks that `module.clip_grad_norm_` can sometimes reference a `torch.Tensor`, which in practice never happens (as the other unit tests show). I could add another, somewhat goofy guard, something like `if isinstance(module.clip_grad_norm_, Tensor): return`, but I'll leave the decision on how to proceed up to you.
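If it helps, my understanding (an assumption, not verified against the stubs used in CI) is that mypy resolves `module.clip_grad_norm_` through `torch.nn.Module.__getattr__`, whose return annotation is `Union[Tensor, Module]`, hence the `"Tensor" not callable` report. A typing-safe alternative sketch would be to narrow the module type rather than the attribute, since `FullyShardedDataParallel` defines `clip_grad_norm_` as a real method:

```python
# Sketch of an alternative guard (assumption: `module` is always the FSDP-wrapped
# root module when this plugin is active). Narrowing the module type lets mypy
# resolve clip_grad_norm_ as a bound method instead of falling back to
# Module.__getattr__, which is annotated as returning Union[Tensor, Module].
from typing import Optional, Union

from torch.distributed.fsdp import FullyShardedDataParallel
from torch.nn import Module
from torch.optim import Optimizer


def clip_grad_by_norm(module: Optional[Module], optimizer: Optimizer, clip_val: Union[int, float]) -> None:
    if not isinstance(module, FullyShardedDataParallel):
        return
    module.clip_grad_norm_(clip_val)
```

Whether replacing the `module is None` check with an isinstance narrowing is acceptable for this plugin is, again, your call.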
I can't identify anything particular to this PR that would be causing the docs checks to fail. Given that, is it alright if we merge this PR as is?