physicsnemo
Enable saving/loading Module compiled with torch dynamo and dynamically set AMP in CorrDiff Module wrappers
PhysicsNeMo Pull Request
Description
This PR:
- Implements utilities to strip out torch dynamo wrappers when attempting to save a `Module` that was compiled with `model = torch.compile(model)`.
- Adds log messages when attempting to load a `Module` checkpoint into an already compiled `Module`.
- Performs similar stripping for the associated optimizer when attempting to save it to a checkpoint.
- Re-orders the loading/compilation logic in CorrDiff `train.py`.
- Adds more explicit and clear log messages when patching is automatically disabled in CorrDiff training.
- Refactors the `amp_mode` attribute of the CorrDiff wrappers `UNet` and `EDMPrecondSuperResolution`, so that AMP can be dynamically set/unset after loading checkpoints.
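For context on the first two items: `torch.compile(model)` wraps the module in an `OptimizedModule` whose state-dict keys carry a `_orig_mod.` prefix, so a checkpoint saved from a compiled model cannot be loaded directly into an uncompiled one. A minimal sketch of the key-stripping idea (the helper name is hypothetical, not the actual PhysicsNeMo utility):

```python
def strip_compile_prefix(state_dict: dict) -> dict:
    """Remove the '_orig_mod.' prefix that torch.compile's OptimizedModule
    wrapper adds to state-dict keys, so the checkpoint can later be loaded
    into a plain, uncompiled Module.

    Hypothetical helper for illustration, not the PhysicsNeMo API.
    """
    prefix = "_orig_mod."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Keys as they appear in the state dict of a torch.compile-wrapped model:
compiled_sd = {"_orig_mod.weight": 1, "_orig_mod.bias": 2}
clean_sd = strip_compile_prefix(compiled_sd)
print(clean_sd)  # {'weight': 1, 'bias': 2}
```

The same normalization applies when saving the optimizer, since its parameter groups reference the compiled model's parameters.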
Closes #883. Closes #874.
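The `amp_mode` refactor can be illustrated with a toy wrapper (assumed class name and forward signature; the real `UNet`/`EDMPrecondSuperResolution` wrappers differ): keeping `amp_mode` as a plain mutable attribute and entering autocast per forward call means AMP can be toggled after a checkpoint is loaded, without rebuilding the model.

```python
# Toy illustration, not the actual PhysicsNeMo wrapper classes.
import torch
import torch.nn as nn

class AmpWrapperSketch(nn.Module):
    def __init__(self, model: nn.Module, amp_mode: bool = False):
        super().__init__()
        self.model = model
        self.amp_mode = amp_mode  # plain attribute: can be flipped at any time

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        device_type = "cuda" if x.is_cuda else "cpu"
        # autocast is entered on every forward call, so changing
        # self.amp_mode between calls takes effect immediately.
        with torch.autocast(device_type, enabled=self.amp_mode):
            return self.model(x)

wrapper = AmpWrapperSketch(nn.Linear(4, 2))
y_fp32 = wrapper(torch.randn(1, 4))  # AMP off
wrapper.amp_mode = True              # e.g. after loading a checkpoint
y_amp = wrapper(torch.randn(1, 4))   # AMP on for subsequent calls
```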
Checklist
- [x] I am familiar with the Contributing Guidelines.
- [x] New or existing tests cover these changes.
- [x] The documentation is up to date with these changes.
- [x] The CHANGELOG.md is up to date with these changes.
- [x] An issue is linked to this pull request.
Dependencies
None.
/blossom-ci