
🐛[BUG]: Modules defined in main script cannot be loaded

Open · nbren12 opened this issue 2 months ago · 4 comments

Version

1.2.0

On which installation method(s) does this occur?

No response

Describe the issue

I defined a module in a training script and passed a name to its metadata. A checkpoint of this model cannot be loaded from another script, even if I make sure to import the train module.

I think I am confused about which information is used to register a new module.

Is it the name field of physicsnemo.ModelMetaData? That is what I expected.

Or is it the class name and module? If so, does moving the class to a different module make existing checkpoints incompatible? How should this be handled?

Minimum reproducible example

# train.py
from dataclasses import dataclass
import physicsnemo
import torch.nn as nn
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint


# modified from https://docs.nvidia.com/physicsnemo/latest/physicsnemo/api/models/modules.html#how-to-write-your-own-physicsnemo-model


class UNetExample(physicsnemo.Module):
    def __init__(self, in_channels=1, out_channels=1, outc=None):
        super().__init__(meta=physicsnemo.ModelMetaData(name="UnetExample"))

        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)

        self.dec1 = self.upconv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x = self.dec1(x2)
        return self.final(x)



if __name__ == "__main__":
    # normal save works
    unet = UNetExample()
    # writes out.mdlus, the checkpoint passed to load_module.py below
    unet.save("out.mdlus")


Loading script

$ cat load_module.py
import physicsnemo
import train # imported so that model can register itself
import sys

physicsnemo.Module.from_checkpoint(sys.argv[1])

Relevant log output

$ python3 train.py
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
  pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
  key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
$ python3 load_module.py out.mdlus
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
  pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
  key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
Model {'__name__': 'UNetExample', '__module__': '__main__', '__args__': {'in_channels': 1, 'out_channels': 1, 'outc': None}}
Traceback (most recent call last):
  File "/lustre/fs1/portfolios/coreai/projects/coreai_climate_earth2/nbrenowitz/repos/edm-chaos/load_module.py", line 5, in <module>
    physicsnemo.Module.from_checkpoint(sys.argv[1])
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 517, in from_checkpoint
    model = Module.instantiate(args)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 278, in instantiate
    return _cls(**arg_dict["__args__"])
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 74, in __new__
    bound_args = sig.bind_partial(
                 ^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/inspect.py", line 3249, in bind_partial
    return self._bind(args, kwargs, partial=True)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/inspect.py", line 3231, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'in_channels'

Environment details


nbren12 · Oct 22 '25 20:10

> Is it the name field physicsnemo.ModelMetaData? This was what I expected.

AFAIK the name does not matter much: it is only used to infer a default filename when saving/loading checkpoints (and in some logging functionality). For example, in your snippet above you could set name="FooBar" in the module's metadata, and unet.save() would produce a FooBar.mdlus checkpoint, but the class stored in the checkpoint would still be UNetExample.
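To make the split between "name" and "class" concrete, here is a toy model of the behavior described above. ModelMetaData and Module are simplified stand-ins written for illustration, and checkpoint_info is a hypothetical helper; none of this is PhysicsNeMo's actual implementation. The metadata name only feeds the default file name, while the class record comes from type(self).

```python
from dataclasses import dataclass


@dataclass
class ModelMetaData:
    """Toy stand-in for physicsnemo.ModelMetaData (hypothetical, illustrative only)."""

    name: str


class Module:
    """Toy stand-in for physicsnemo.Module; not the library's real implementation."""

    def __init__(self, meta):
        self.meta = meta

    def checkpoint_info(self, file_name=None):
        """Hypothetical helper: return (file name, class record) for a save."""
        # The metadata name only supplies the default file name...
        if file_name is None:
            file_name = f"{self.meta.name}.mdlus"
        # ...while the class record comes from the object's actual type.
        cls = type(self)
        return file_name, {"__name__": cls.__name__, "__module__": cls.__module__}


class UNetExample(Module):
    def __init__(self):
        super().__init__(meta=ModelMetaData(name="FooBar"))


file_name, record = UNetExample().checkpoint_info()
print(file_name, record["__name__"])  # FooBar.mdlus UNetExample
```

Renaming the metadata therefore changes the default file name but not which class a loader will try to rebuild.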

> Is the class name and module? If so, if I refactor the module to a different location, does this then mean the checkpoint is no longer compatible? How should this be handled?

Yes, class.__name__ is used to instantiate and register the Module. More precisely, two things happen when saving/loading:

  1. There is a mechanism to infer the class of the instance being saved.
  2. The path of the class definition (accessed through class.__module__) is written to the checkpoint file when saving, and is used to re-instantiate the Module when loading.

Since this relies on class.__module__, a refactor such as moving the class to another module might break loading of existing checkpoints.
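The failure in the original report follows from this: the log shows the checkpoint recorded '__module__': '__main__', and in the loading script "__main__" is load_module.py itself, not train.py. The sketch below reproduces only that lookup step with the standard library; resolve_class and the module name train_demo are hypothetical, and the real PhysicsNeMo code path may differ.

```python
import importlib
import sys
import types


def resolve_class(arg_dict):
    """Look up a class from the (__module__, __name__) pair stored in a checkpoint.

    Hypothetical sketch of the lookup step, not PhysicsNeMo's actual code.
    Returns None when the module does not define that class.
    """
    module = importlib.import_module(arg_dict["__module__"])
    return getattr(module, arg_dict["__name__"], None)


# Simulate an importable module "train_demo" (stands in for train.py;
# the name is hypothetical) that defines the class.
train_demo = types.ModuleType("train_demo")
train_demo.UNetExample = type("UNetExample", (), {"__module__": "train_demo"})
sys.modules["train_demo"] = train_demo

# Recorded when train.py is run as a script: the class lives in "__main__".
saved_from_script = {"__name__": "UNetExample", "__module__": "__main__"}
# Recorded if the class had been defined in an importable module.
saved_from_module = {"__name__": "UNetExample", "__module__": "train_demo"}

# In the loader, "__main__" is the loader itself, which has no UNetExample.
print(resolve_class(saved_from_script))  # None
print(resolve_class(saved_from_module))  # <class 'train_demo.UNetExample'>
```

Importing train in the loader does not help here, because the stored path points at "__main__" rather than at "train".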

But in your snippet above, when loading the Module could you try:

  • train.UNetExample.from_checkpoint(sys.argv[1])

instead of:

  • physicsnemo.Module.from_checkpoint(sys.argv[1])

I would expect the former to work, since the mechanism that infers the class should resolve it to UNetExample (irrespective of where you moved this class during a potential refactor), whereas the latter has to fall back to the class.__module__ saved in your checkpoint file.
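One way to picture why the subclass call is more robust is the toy classmethod below. It is a hypothetical model of the behavior described in this comment, not PhysicsNeMo's implementation: when from_checkpoint is called on a subclass, cls already identifies the class, so the stored module path is never consulted; when called on the base class, the stored path is the only information available. The module name old_train is hypothetical and deliberately not importable, to mimic a refactor.

```python
import importlib


class Module:
    """Toy stand-in for physicsnemo.Module (hypothetical, illustrative only)."""

    @classmethod
    def from_checkpoint(cls, arg_dict):
        if cls is Module:
            # Called on the base class: only the stored (__module__, __name__)
            # pair is available to find the class.
            module = importlib.import_module(arg_dict["__module__"])
            target = getattr(module, arg_dict["__name__"])
        else:
            # Called on a subclass: the class is already known from `cls`,
            # so the stored module path is not needed.
            target = cls
        return target(**arg_dict["__args__"])


class UNetExample(Module):
    def __init__(self, in_channels=1, out_channels=1):
        self.in_channels = in_channels
        self.out_channels = out_channels


# Metadata recorded before a refactor moved the class out of "old_train"
# (a hypothetical module that no longer exists).
ckpt = {
    "__name__": "UNetExample",
    "__module__": "old_train",
    "__args__": {"in_channels": 3, "out_channels": 2},
}

model = UNetExample.from_checkpoint(ckpt)  # class resolved from `cls`
print(type(model).__name__, model.in_channels)  # UNetExample 3

try:
    Module.from_checkpoint(ckpt)
except ImportError as err:  # ModuleNotFoundError subclasses ImportError
    print("base-class load failed:", err)
```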

If that doesn't work, please let me know, because that's a bug that we should fix.

CharlelieLrt · Oct 22 '25 22:10

Thanks @CharlelieLrt for explaining what the name parameter does. Indeed, this works:

train.UNetExample.from_checkpoint(sys.argv[1])

I suppose I can use the entrypoint mechanism if I want to move the class definition to a new file.
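The idea behind entry points is that a stable name maps to a "module:attribute" string, so moving a class only requires updating that string. The sketch below demonstrates the resolution step with the standard library's importlib.metadata.EntryPoint. The group name physicsnemo.models and the module new_location are assumptions for illustration; check the PhysicsNeMo model registry docs for the group your version actually scans.

```python
import sys
import types
from importlib.metadata import EntryPoint

# Simulate the class's new home after a refactor ("new_location" is hypothetical).
new_location = types.ModuleType("new_location")
new_location.UNetExample = type("UNetExample", (), {"__module__": "new_location"})
sys.modules["new_location"] = new_location

# In a real package this would be declared in pyproject.toml, e.g.
#   [project.entry-points."physicsnemo.models"]   # group name is an assumption
#   UNetExample = "new_location:UNetExample"
ep = EntryPoint(
    name="UNetExample",
    value="new_location:UNetExample",
    group="physicsnemo.models",
)

loaded_cls = ep.load()  # imports new_location and returns its UNetExample attribute
print(loaded_cls)  # <class 'new_location.UNetExample'>
```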

nbren12 · Oct 23 '25 01:10

It seems there is no bug after all, and the issue could be resolved by documenting this behavior.

nbren12 · Oct 23 '25 01:10

Glad to hear it works. Let me know if you run into other issues with the entrypoint mechanism. I'll also add an explanation of this to the docs, since many users rely on the from_checkpoint method of physicsnemo.Module rather than calling it on their specific subclass.

CharlelieLrt · Oct 23 '25 02:10