
Enable saving/loading Module compiled with torch dynamo and dynamically set AMP in CorrDiff Module wrappers

Open · CharlelieLrt opened this issue 7 months ago • 1 comment

PhysicsNeMo Pull Request

Description

This PR:

  • Implements utilities to strip out torch dynamo wrappers when attempting to save a Module that was compiled with model = torch.compile(model).
  • Adds log messages when attempting to load a Module checkpoint into an already compiled Module.
  • Performs similar operations for the associated optimizer when saving it to a checkpoint.
  • Re-orders the loading/compilation logic in CorrDiff train.py.
  • Adds clearer, more explicit log messages when patching is automatically disabled in CorrDiff training.
  • Refactors the amp_mode of the CorrDiff wrappers UNet and EDMPrecondSuperResolution, such that AMP can be dynamically set/unset after loading checkpoints.

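For context on the first item: torch.compile wraps the original module in an OptimizedModule, so its state_dict keys carry a "_orig_mod." prefix that prevents loading the checkpoint into an uncompiled Module. A minimal sketch of the stripping idea is below; strip_dynamo_prefix is a hypothetical helper name for illustration, not the actual PhysicsNeMo API.

```python
def strip_dynamo_prefix(state_dict: dict) -> dict:
    """Remove the '_orig_mod.' prefix that torch.compile adds to
    parameter keys, so the checkpoint loads into a plain Module.
    (Hypothetical helper, sketching the idea only.)"""
    prefix = "_orig_mod."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }

# Keys as they would appear in a compiled model's state_dict:
compiled_sd = {"_orig_mod.weight": 1, "_orig_mod.bias": 2}
clean_sd = strip_dynamo_prefix(compiled_sd)  # {"weight": 1, "bias": 2}
```

In practice such a helper would run inside the Module's save path, so checkpoints written from a compiled model remain loadable by an uncompiled one.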
Closes #883. Closes #874.
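The amp_mode refactor can be pictured as exposing AMP as a settable property rather than a constructor-only flag, so it can be flipped after a checkpoint is loaded. The class and attribute names below are illustrative placeholders, not the real wrapper implementations.

```python
class PrecondWrapperSketch:
    """Hypothetical sketch of a CorrDiff-style wrapper whose AMP
    mode can be toggled after checkpoint loading, instead of being
    fixed at construction time."""

    def __init__(self, amp_mode: bool = False):
        self._amp_mode = amp_mode

    @property
    def amp_mode(self) -> bool:
        return self._amp_mode

    @amp_mode.setter
    def amp_mode(self, value: bool) -> None:
        if not isinstance(value, bool):
            raise TypeError("amp_mode must be a bool")
        self._amp_mode = value

model = PrecondWrapperSketch()   # e.g. restored from a checkpoint
model.amp_mode = True            # enable AMP dynamically afterwards
```

A property setter keeps validation in one place while letting training scripts enable or disable AMP without reconstructing the model.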

Checklist

  • [x] I am familiar with the Contributing Guidelines.
  • [x] New or existing tests cover these changes.
  • [x] The documentation is up to date with these changes.
  • [x] The CHANGELOG.md is up to date with these changes.
  • [x] An issue is linked to this pull request.

Dependencies

None.

CharlelieLrt avatar May 10 '25 03:05 CharlelieLrt

/blossom-ci

CharlelieLrt avatar May 10 '25 03:05 CharlelieLrt

/blossom-ci

CharlelieLrt avatar May 15 '25 02:05 CharlelieLrt

/blossom-ci

CharlelieLrt avatar May 15 '25 19:05 CharlelieLrt

/blossom-ci

CharlelieLrt avatar May 16 '25 22:05 CharlelieLrt