🐛[BUG]: Modules defined in main script cannot be loaded
Version
1.2.0
On which installation method(s) does this occur?
No response
Describe the issue
I defined a module in a training script and passed a name to its metadata. This model cannot be loaded from another script, even if I make sure to import the train module.
I think I am confused about which of the provided information is used to register a new module.
Is it the name field physicsnemo.ModelMetaData? This was what I expected.
Is it the class name and module? If so, if I refactor the module to a different location, does that mean the checkpoint is no longer compatible? How should this be handled?
Minimum reproducible example
# train.py
from dataclasses import dataclass
import physicsnemo
import torch.nn as nn
from physicsnemo.launch.utils import load_checkpoint, save_checkpoint

# modified from https://docs.nvidia.com/physicsnemo/latest/physicsnemo/api/models/modules.html#how-to-write-your-own-physicsnemo-model
class UNetExample(physicsnemo.Module):
    def __init__(self, in_channels=1, out_channels=1, outc=None):
        super().__init__(meta=physicsnemo.ModelMetaData(name="UnetExample"))
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.dec1 = self.upconv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, 2, stride=2),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x = self.dec1(x2)
        return self.final(x)

if __name__ == "__main__":
    # normal save works
    unet = UNetExample()
    unet.save("out.mdlus")
Loading script
$ cat load_module.py
import physicsnemo
import train # imported so that model can register itself
import sys
physicsnemo.Module.from_checkpoint(sys.argv[1])
Relevant log output
$ python3 train.py
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
$ python3 load_module.py out.mdlus
/usr/local/lib/python3.12/dist-packages/physicsnemo/utils/filesystem.py:76: SyntaxWarning: invalid escape sequence '\w'
pattern = re.compile(f"{suffix}[\w-]+(/[\w-]+)?/[\w-]+@[A-Za-z0-9.]+/[\w/](.*)")
/usr/local/lib/python3.12/dist-packages/physicsnemo/launch/logging/launch.py:321: SyntaxWarning: invalid escape sequence '\.'
key = re.sub("[^a-zA-Z0-9\.\-\s\/\_]+", "", key)
Model {'__name__': 'UNetExample', '__module__': '__main__', '__args__': {'in_channels': 1, 'out_channels': 1, 'outc': None}}
Traceback (most recent call last):
File "/lustre/fs1/portfolios/coreai/projects/coreai_climate_earth2/nbrenowitz/repos/edm-chaos/load_module.py", line 5, in <module>
physicsnemo.Module.from_checkpoint(sys.argv[1])
File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 517, in from_checkpoint
model = Module.instantiate(args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 278, in instantiate
return _cls(**arg_dict["__args__"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/python3.12/dist-packages/physicsnemo/models/module.py", line 74, in __new__
bound_args = sig.bind_partial(
^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 3249, in bind_partial
return self._bind(args, kwargs, partial=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.12/inspect.py", line 3231, in _bind
raise TypeError(
TypeError: got an unexpected keyword argument 'in_channels'
Environment details
Is it the name field physicsnemo.ModelMetaData? This was what I expected.
AFAIK the name does not matter much: it is only used to infer a default filename when saving/loading checkpoints (and in some logging functionality). For example, in your snippet above you could set name="FooBar" in the module's metadata, and unet.save() would produce a FooBar.mdlus checkpoint, but the class of the checkpoint would still be UNetExample.
Is it the class name and module? If so, if I refactor the module to a different location, does that mean the checkpoint is no longer compatible? How should this be handled?
Yes, the class.__name__ is used to instantiate and register the Module. More precisely, two things happen when saving/loading:
- There is a mechanism to infer the class of the instance being saved.
- The path of the class definition (accessed through class.__module__) is saved to the checkpoint file (when saving), or used to reinstantiate the Module (when loading).
Since this relies on class.__module__, a refactor such as moving the class to another location might break it.
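To make the path-based lookup concrete, here is a simplified sketch in plain Python. It assumes only what the checkpoint record in the log above shows (a `__name__`, a `__module__` and the constructor `__args__`); `resolve_class` is a hypothetical helper, not physicsnemo's actual code, whose real lookup differs in detail (its error in the traceback above is a TypeError, for instance). The last case also shows why the log records `'__module__': '__main__'`: a class defined in a script that is run directly lives in `__main__`, which at load time names the loading script instead.

```python
import importlib


def resolve_class(record: dict) -> type:
    """Re-import a class from the module path stored in a checkpoint record.

    Hypothetical helper for illustration only; not physicsnemo's implementation.
    """
    # import_module returns the already-imported module if it is in sys.modules.
    module = importlib.import_module(record["__module__"])
    return getattr(module, record["__name__"])


# A record pointing at a stable, importable location resolves fine.
ok = resolve_class({"__module__": "collections", "__name__": "OrderedDict"})
print(ok)  # <class 'collections.OrderedDict'>

# A record written before a refactor may point at a module that no longer exists.
try:
    resolve_class({"__module__": "old_package.unet", "__name__": "UNetExample"})
except ModuleNotFoundError as err:
    print("lookup failed:", err)

# A class saved from a directly-run script is recorded under "__main__",
# which at load time is the *loading* script, where UNetExample is not defined.
try:
    resolve_class({"__module__": "__main__", "__name__": "UNetExample"})
except AttributeError as err:
    print("lookup failed:", err)
```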
But in your snippet above, when loading the Module could you try:
train.UNetExample.from_checkpoint(sys.argv[1])
instead of:
physicsnemo.Module.from_checkpoint(sys.argv[1])
I would expect the former to work, since the mechanism that infers the class should resolve it to UNetExample (irrespective of where you moved the class during a potential refactor), whereas the latter should fall back to the class.__module__ saved in your checkpoint file.
If that doesn't work, please let me know, because that's a bug that we should fix.
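The difference between the two calls can be sketched with a toy base class. This assumes, as described above, that calling the loader on a subclass uses that subclass directly, while calling it on the base class falls back to the stored module path. `Module`, `from_record` and the record layout are simplified, hypothetical stand-ins, not physicsnemo's API.

```python
import importlib


class Module:
    """Toy stand-in for physicsnemo.Module (hypothetical, simplified)."""

    def __init__(self, **kwargs):
        # Keep constructor arguments so the object could be saved and rebuilt.
        self.args = kwargs

    @classmethod
    def from_record(cls, record: dict) -> "Module":
        if cls is not Module:
            # Called on a subclass: the class is already known, so the
            # stored __module__ path is never consulted.
            target = cls
        else:
            # Called on the base class: fall back to the stored path.
            module = importlib.import_module(record["__module__"])
            target = getattr(module, record["__name__"])
        return target(**record["__args__"])


class UNetExample(Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__(in_channels=in_channels, out_channels=out_channels)


record = {
    "__name__": "UNetExample",
    "__module__": "old_package.unet",  # stale path, e.g. after a refactor
    "__args__": {"in_channels": 1, "out_channels": 1},
}

model = UNetExample.from_record(record)  # works: uses the subclass directly
print(type(model).__name__, model.args)

try:
    Module.from_record(record)  # fails: the stored module path is gone
except ModuleNotFoundError as err:
    print("base-class load failed:", err)
```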
Thanks @CharlelieLrt for the explanation of what the name parameter does. Indeed, this works:
train.UNetExample.from_checkpoint(sys.argv[1])
I suppose I can use the entrypoint mechanism if I want to move the class definition to a new file.
It seems there is no "bug", and the issue could be resolved with some description in the docs.
Glad to know that it works. Let me know if you encounter other issues with the entrypoint mechanism.
I'll also add an explanation of this to the docs, since many users rely on the from_checkpoint method of physicsnemo.Module instead of the one on their specific subclass.