pytorch-lightning
Fix ModelParallelStrategy failing with non-distributed checkpoints
What does this PR do?
- Adds a CUDA-only integration test that mirrors the reporter's compiled ModelParallel setup so that the `KeyError('model.0.weight')` reproduces in CI.
- Fixes `ModelParallelStrategy.optimizer_state` so that, when `torch.compile` wraps the module, optimizer states are rekeyed through both the compiled wrapper and the original module before single-file checkpointing, preventing the `KeyError`.
- Documents the fix in the unreleased section of the changelog.
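The rekeying idea in the fix above can be illustrated without PyTorch. The sketch below is not the PR's code: it assumes that the `torch.compile` wrapper inserts an `_orig_mod` component into parameter names (for example `model._orig_mod.0.weight`) while the optimizer state is keyed by the original-module name (`model.0.weight`), so a lookup built only from the wrapper's names raises the reported `KeyError`. `canonical_name`, `build_name_to_index`, and `rekey_state_by_index` are hypothetical helpers.

```python
from typing import Any, Dict, List

# Assumed path component added by the torch.compile wrapper to parameter names,
# e.g. "model._orig_mod.0.weight" instead of "model.0.weight".
COMPILED_COMPONENT = "_orig_mod"


def canonical_name(name: str) -> str:
    """Drop the compiled-wrapper component so the name matches the original module."""
    return ".".join(part for part in name.split(".") if part != COMPILED_COMPONENT)


def build_name_to_index(param_names: List[str]) -> Dict[str, int]:
    """Index each parameter under both its wrapper name and its original-module name."""
    mapping: Dict[str, int] = {}
    for index, name in enumerate(param_names):
        mapping[name] = index                  # spelling seen through the compiled wrapper
        mapping[canonical_name(name)] = index  # spelling seen through the original module
    return mapping


def rekey_state_by_index(fqn_state: Dict[str, Any], param_names: List[str]) -> Dict[int, Any]:
    """Convert name-keyed optimizer state into index-keyed state for a single-file checkpoint."""
    name_to_index = build_name_to_index(param_names)
    rekeyed: Dict[int, Any] = {}
    for fqn, state in fqn_state.items():
        # A name unknown under both spellings is still a genuine error.
        if fqn not in name_to_index:
            raise KeyError(fqn)
        rekeyed[name_to_index[fqn]] = state
    return rekeyed


if __name__ == "__main__":
    wrapper_names = ["model._orig_mod.0.weight", "model._orig_mod.0.bias"]
    optimizer_state = {"model.0.weight": {"step": 1}, "model.0.bias": {"step": 1}}

    # A lookup built only from the wrapper's names reproduces the reported failure.
    naive = {name: i for i, name in enumerate(wrapper_names)}
    try:
        naive["model.0.weight"]
    except KeyError as err:
        print("naive lookup fails:", err)

    # Mapping both spellings lets the same state be rekeyed by parameter index.
    print(rekey_state_by_index(optimizer_state, wrapper_names))
```

Registering both spellings, rather than rewriting the optimizer's keys, keeps lookups working whether a caller holds the compiled wrapper or the original module.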
Fixes #21357
Before submitting
- [ ] Was this discussed/agreed via a GitHub issue? (not for typos and docs)
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- [ ] Did you make sure to update the documentation with your changes? (if necessary)
- [ ] Did you write any new necessary tests? (not for typos and docs)
- [x] Did you verify new and existing tests pass locally with your changes? (CUDA test runs in CI; CPU run skips as expected)
- [ ] Did you list all the breaking changes introduced by this pull request?
- [x] Did you update the CHANGELOG? (not for typos, docs, test updates, or minor internal changes/refactors)
Codecov Report
:x: Patch coverage is 92.30769% with 2 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 79%. Comparing base (79ffe50) to head (2e5ae5f).
:white_check_mark: All tests successful. No failed tests found.
:exclamation: A different number of reports was uploaded for BASE (79ffe50) than for HEAD (2e5ae5f).
HEAD has 925 fewer uploads than BASE.
| Flag | BASE (79ffe50) | HEAD (2e5ae5f) |
|---|---|---|
| cpu | 239 | 30 |
| lightning_fabric | 60 | 0 |
| pytest | 120 | 0 |
| python3.12 | 71 | 9 |
| python3.12.7 | 72 | 9 |
| lightning | 120 | 15 |
| python3.11 | 48 | 6 |
| python3.10 | 24 | 3 |
| python | 24 | 3 |
| pytorch2.1 | 24 | 6 |
| pytest-full | 119 | 30 |
| pytorch_lightning | 59 | 15 |
| pytorch2.6 | 12 | 3 |
| pytorch2.4.1 | 12 | 3 |
| pytorch2.3 | 12 | 3 |
| pytorch2.2.2 | 12 | 3 |
| pytorch2.5.1 | 12 | 3 |
| pytorch2.9 | 12 | 3 |
| pytorch2.7 | 12 | 3 |
| pytorch2.8 | 11 | 3 |
Additional details and impacted files
Coverage diff:

| | master | #21384 | +/- |
|---|---|---|---|
| Coverage | 87% | 79% | -8% |
| Files | 269 | 266 | -3 |
| Lines | 23804 | 23775 | -29 |
| Hits | 20626 | 18752 | -1874 |
| Misses | 3178 | 5023 | +1845 |
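The headline percentages can be checked against the table, taking coverage as hits divided by tracked lines (hits plus misses equals lines in both columns, so no partial lines are involved). For the patch figure, the derivation assumes the 2 uncovered lines are counted against every changed line Codecov tracked:

$$
\text{coverage} = \frac{\text{hits}}{\text{lines}}:\qquad
\frac{20626}{23804} \approx 86.65\% \to 87\%,\qquad
\frac{18752}{23775} \approx 78.87\% \to 79\%
$$

$$
\text{patch} = \frac{n-2}{n} = 0.9230769 \approx \frac{12}{13} \;\Rightarrow\; n = 26 \text{ tracked changed lines, } 24 \text{ covered}
$$

Because tracked lines barely moved (-29) while hits fell by 1874, the drop is consistent with the 925 missing uploads reported above rather than with the patch itself.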