physicsnemo
🐛[BUG]: Training CorrDiff with torch.compile does not produce mdlus checkpoints
Version
main
On which installation method(s) does this occur?
No response
Describe the issue
When training CorrDiff with training.perf.torch_compile=True, the training script only produces checkpoints of the following form:
checkpoint.0.<step>.pt
OptimizedModule.0.<step>.pt # compiled model
These cannot be loaded via Module.from_checkpoint in generate.py. I believe that when using torch.compile we want to pass something like model._orig_mod (the original, uncompiled module that the compiled wrapper holds) to save_checkpoint instead, so we get proper checkpoints like EDMPrecondSuperResolution.0.<step>.mdlus.
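A minimal sketch of the unwrapping idea, not the actual fix. It relies on the fact that torch.compile wraps a module in an OptimizedModule that exposes the original module as `_orig_mod`. The helper name `unwrap_compiled` is hypothetical, and the two classes below are stand-ins so the snippet runs without torch or physicsnemo installed; they are not the real classes.

```python
def unwrap_compiled(model):
    """Return the original module if `model` was wrapped by torch.compile.

    torch.compile returns an OptimizedModule that keeps the wrapped module
    in `_orig_mod`; an uncompiled model has no such attribute and is
    returned unchanged.
    """
    return getattr(model, "_orig_mod", model)


# Stand-in for the real model class (placeholder, assumption for this sketch).
class EDMPrecondSuperResolution:
    pass


# Stand-in mimicking the wrapper that torch.compile returns.
class OptimizedModule:
    def __init__(self, orig_mod):
        self._orig_mod = orig_mod


model = EDMPrecondSuperResolution()
compiled = OptimizedModule(model)

# The wrapper's class name is what currently ends up in the checkpoint name.
print(type(compiled).__name__)
# After unwrapping, the original class name is visible again.
print(type(unwrap_compiled(compiled)).__name__)
```

In the training script, the unwrapped module would presumably be what gets passed to save_checkpoint; the exact call site and arguments are not shown here and would need to be checked against the CorrDiff train.py.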
Minimum reproducible example
Relevant log output
Environment details