
🐛[BUG]: Training CorrDiff with torch.compile does not produce mdlus checkpoints

Open • swbg opened this issue 7 months ago • 1 comment

Version

main

On which installation method(s) does this occur?

No response

Describe the issue

When training CorrDiff with training.perf.torch_compile=True, the training script only produces checkpoints of the following form:

checkpoint.0.<step>.pt
OptimizedModule.0.<step>.pt  # compiled model

These cannot be loaded via Module.from_checkpoint in generate.py. When torch.compile is used, I believe we should pass the underlying model (something like model._orig_mod) to save_checkpoint instead, so that we get proper checkpoints such as EDMPrecondSuperResolution.0.<step>.mdlus.
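A minimal sketch of why the wrapper's class name ends up in the file name and how unwrapping would avoid it. It assumes, as the file names above suggest, that save_checkpoint names each model file <ClassName>.<rank>.<step> and writes .mdlus only for physicsnemo Modules. Module, EDMPrecondSuperResolution, OptimizedModule and DistributedDataParallel below are hypothetical stand-ins so the sketch runs without torch or physicsnemo; checkpoint_filename and unwrap_model are illustrative helpers, not physicsnemo API. Since torch.compile does not copy parameters, the original module behind _orig_mod holds the trained weights, so saving it loses nothing.

```python
# Hypothetical stand-ins so this sketch runs without torch/physicsnemo:
#   Module                  ~ physicsnemo.Module (can be saved as .mdlus)
#   OptimizedModule         ~ the wrapper returned by torch.compile
#   DistributedDataParallel ~ torch.nn.parallel.DistributedDataParallel


class Module:
    """Stand-in for physicsnemo.Module."""


class EDMPrecondSuperResolution(Module):
    """Stand-in for the CorrDiff model."""


class OptimizedModule:
    """Stand-in for torch.compile's wrapper; keeps the original in _orig_mod."""

    def __init__(self, orig_mod):
        self._orig_mod = orig_mod


class DistributedDataParallel:
    """Stand-in for DDP; keeps the original in module."""

    def __init__(self, module):
        self.module = module


def checkpoint_filename(model, rank, step):
    """Assumed naming scheme, inferred from the file names in this issue:
    <ClassName>.<rank>.<step>.<ext>, with .mdlus only for physicsnemo Modules."""
    ext = "mdlus" if isinstance(model, Module) else "pt"
    return f"{type(model).__name__}.{rank}.{step}.{ext}"


def unwrap_model(model):
    """Peel torch.compile (_orig_mod) and DDP (module) wrappers, in any
    nesting order, until the underlying model is reached."""
    while True:
        if hasattr(model, "_orig_mod"):
            model = model._orig_mod
        elif hasattr(model, "module"):
            model = model.module
        else:
            return model


net = EDMPrecondSuperResolution()
compiled = OptimizedModule(DistributedDataParallel(net))

print(checkpoint_filename(compiled, 0, 1000))                # OptimizedModule.0.1000.pt
print(checkpoint_filename(unwrap_model(compiled), 0, 1000))  # EDMPrecondSuperResolution.0.1000.mdlus
```

In the training script this would amount to something like save_checkpoint(..., models=unwrap_model(net), ...), where net is a hypothetical name for the compiled model. Matching on attribute names is a heuristic (a model that has its own attribute called module would also be unwrapped); a stricter version could check isinstance against torch._dynamo.eval_frame.OptimizedModule and torch.nn.parallel.DistributedDataParallel instead.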

Minimum reproducible example


Relevant log output


Environment details


swbg, May 09 '25 09:05