🐛[BUG]: CorrDiff - Double nested `model_kwargs` passed to U-Net constructor
Version
0.7.0a
On which installation method(s) does this occur?
Docker, Source
Describe the issue
Analysis: The error occurs because `model_kwargs` is double nested when passed to the U-Net constructor:
{'model_kwargs': {'embedding_type': 'zero', 'encoder_type': 'standard', 'decoder_type': 'standard', 'channel_mult_noise': 1, 'resample_filter': [1, 1], 'model_channels': 128, 'channel_mult': [1, 2, 2, 2, 2], 'attn_resolutions': [28], 'dropout': 0.13}}
This could be a problem with how the metadata for the regression model is saved to disk.
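To illustrate the mechanism, here is a minimal sketch that does not use Modulus at all. `ToyUNet` is a hypothetical stand-in for the backbone class, and the only real API involved is `inspect.Signature.bind_partial`, the call that raises in the traceback below. It is meant to show why a wrapped dict fails to bind, not how Modulus builds its arguments.

```python
import inspect


class ToyUNet:
    """Hypothetical stand-in for the backbone class (not the Modulus code)."""

    def __init__(self, model_channels=128, dropout=0.1):
        self.model_channels = model_channels
        self.dropout = dropout


sig = inspect.signature(ToyUNet.__init__)

# Flat kwargs: every key matches a constructor parameter.
flat = {"model_channels": 128, "dropout": 0.13}
# Nested kwargs: the same dict wrapped under a 'model_kwargs' key.
nested = {"model_kwargs": flat}

# The flat dict binds without error (None stands in for `self`).
bound = sig.bind_partial(None, **flat)

# The nested dict is expanded into a single keyword named 'model_kwargs',
# which the constructor does not accept, so binding raises TypeError.
try:
    sig.bind_partial(None, **nested)
    message = ""
except TypeError as exc:
    message = str(exc)

print(message)
```

The printed message names `model_kwargs` as the unexpected keyword, which matches the shape of the error in the log.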
Proposed fix:
- Pass `model_kwargs["model_kwargs"]` instead of `model_kwargs`.
- Investigate and change how the metadata for the regression model is saved.
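A rough sketch of the first option, as a workaround only: `unwrap_model_kwargs` is a hypothetical helper (it does not exist in Modulus) that strips the extra layer when the dict consists solely of a `model_kwargs` key holding another dict. It assumes the backbone has no legitimate parameter named `model_kwargs`; if it does, this would be wrong.

```python
def unwrap_model_kwargs(kwargs, key="model_kwargs"):
    """Strip wrapper layers of the form {key: {...}} and return the inner dict.

    Hypothetical helper for illustration; a dict with any other keys is
    returned unchanged.
    """
    while (
        isinstance(kwargs, dict)
        and set(kwargs) == {key}
        and isinstance(kwargs[key], dict)
    ):
        kwargs = kwargs[key]
    return kwargs


# Shortened version of the kwargs from the analysis above.
nested = {"model_kwargs": {"embedding_type": "zero", "dropout": 0.13}}
flat = unwrap_model_kwargs(nested)
print(flat)
```

This only hides the symptom at load time. Checkpoints would still be written with the nested layout, so the second option (fixing how the metadata is saved) is probably the cleaner long-term fix.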
Minimum reproducible example
python3 train.py --config-name=config_train_diffusion.yaml
config:
arch: ddpmpp-cwb
precond: edmv1
task: diffusion
Relevant log output
Traceback (most recent call last):
  File "/code/modulus/examples/generative/corrdiff/train.py", line 344, in main
    training_loop.training_loop(
  File "/code/modulus/examples/generative/corrdiff/training/training_loop.py", line 176, in training_loop <--------------
    net_reg = Module.from_checkpoint(regression_checkpoint_path)
  File "/code/modulus/modulus/models/module.py", line 357, in from_checkpoint
    model = cls.instantiate(args)
  File "/code/modulus/modulus/models/module.py", line 175, in instantiate
    return _cls(**arg_dict["__args__"])
  File "/code/modulus/modulus/models/diffusion/unet.py", line 111, in __init__ <--------------
    self.model = model_class(
  File "/code/modulus/modulus/models/module.py", line 65, in __new__
    bound_args = sig.bind_partial(
  File "/usr/lib/python3.10/inspect.py", line 3193, in bind_partial
    return self._bind(args, kwargs, partial=True)
  File "/usr/lib/python3.10/inspect.py", line 3175, in _bind
    raise TypeError(
TypeError: got an unexpected keyword argument 'model_kwargs'
Environment details
python version: 3.10
modulus commit: `c07fa25321c48a1d71efca12b67d056adbca8bd4`