
🐛 [BUG] Cannot restart run with different dataset

Open • pablo-unzueta opened this issue 2 years ago • 4 comments

Describe the bug: I am trying to restart a training run using load_model_state or initialize_from_state. I keep receiving an error saying that scale_by in the checkpoint's state_dict is a scalar (shape torch.Size([])), while in the new run it has shape torch.Size([1]):

RuntimeError: Error(s) in loading state_dict for RescaleOutput: size mismatch for scale_by: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([1]).

I also tried load_model_state_strict: false, but that yielded the same error.
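A minimal plain-PyTorch sketch (not nequip code) of why a non-strict option may not help here: in torch.nn.Module.load_state_dict, strict=False only tolerates missing or unexpected keys, while a shape mismatch on a key present in both the checkpoint and the model is still raised as an error. RescaleSketch is a hypothetical stand-in for a module holding a scale_by buffer; whether nequip's *_strict options are passed straight through to this argument is an assumption.

```python
import torch


class RescaleSketch(torch.nn.Module):
    """Hypothetical stand-in for a module that stores a scale_by buffer."""

    def __init__(self, scale_by: torch.Tensor):
        super().__init__()
        self.register_buffer("scale_by", scale_by)


def try_load(strict: bool) -> str:
    # Checkpoint written by a model whose scale_by is 0-dim: torch.Size([])
    checkpoint = RescaleSketch(torch.tensor(2.0)).state_dict()
    # New model whose scale_by has shape torch.Size([1])
    model = RescaleSketch(torch.tensor([1.0]))
    try:
        model.load_state_dict(checkpoint, strict=strict)
    except RuntimeError as err:
        return str(err)
    return "loaded"


for strict in (True, False):
    # Both settings report "size mismatch for scale_by: ..."
    print(f"strict={strict}: {try_load(strict)}")
```

If nequip does forward the option this way, it would be consistent with both load_model_state_strict: false and initialize_from_state_strict: false producing the same error.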

To Reproduce: The YAML files I used are attached. I start training with energy_only.yaml. After training for some time, I want to restart on a different dataset using restart.yaml.

Expected behavior: Training should resume, as described in #343 or #297.

Environment:

  • OS: CentOS 7
  • Python 3.9.12
  • Python environment:
    • nequip version 0.5.5
    • e3nn version 0.5.0
    • pytorch version 1.11.0

Additional context: configs.zip

pablo-unzueta • Jun 27 '23 21:06

Hi @pablo-unzueta ,

Thanks for your interest in our code!

I'm not sure why this is happening, but initialize_from_state_strict: false would be the correct option in this case.

You could also try adding global_rescale_scale: 0.0 to your restart config...

Linux-cpp-lisp • Jun 28 '23 17:06

Hi @Linux-cpp-lisp

Thanks for your advice! I tried initialize_from_state_strict: false and received the same error: RuntimeError: Error(s) in loading state_dict for RescaleOutput: size mismatch for scale_by: copying a param with shape torch.Size([]) from checkpoint, the shape in current model is torch.Size([1]).
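When the only difference is the shape of a tensor such as scale_by, one generic PyTorch workaround is to reshape mismatched checkpoint tensors to the model's shapes before loading. The sketch below is hypothetical: reshape_to_model and ScaleHolder are illustrative names, this is not a nequip option, and it assumes the mismatched tensors have the same number of elements (as a 0-dim tensor and a shape-[1] tensor do).

```python
import torch


def reshape_to_model(model: torch.nn.Module, state: dict) -> dict:
    """Hypothetical helper: reshape checkpoint tensors whose shape differs
    from the model's but whose element count matches (e.g. [] -> [1])."""
    target = model.state_dict()
    fixed = {}
    for key, value in state.items():
        if key in target and value.shape != target[key].shape:
            if value.numel() != target[key].numel():
                raise ValueError(
                    f"cannot reshape {key}: {tuple(value.shape)} -> {tuple(target[key].shape)}"
                )
            value = value.reshape(target[key].shape)
        fixed[key] = value
    return fixed


class ScaleHolder(torch.nn.Module):
    """Hypothetical module with a shape-[1] scale_by buffer."""

    def __init__(self):
        super().__init__()
        self.register_buffer("scale_by", torch.ones(1))


model = ScaleHolder()
checkpoint = {"scale_by": torch.tensor(2.5)}  # 0-dim, as in the error message
model.load_state_dict(reshape_to_model(model, checkpoint))
print(model.scale_by)  # tensor([2.5000])
```

Whether a value loaded this way means the same thing in the restarted nequip model is not verified here; it only gets past the shape check.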

I also tried global_rescale_scale: 0.0 and received the following error:

ValueError: Global energy scaling was very low: 0.0. If dataset values were used, does the dataset contain insufficient variation? Maybe try disabling global scaling with global_scale=None.

pablo-unzueta • Jun 28 '23 18:06

I couldn't figure out how to set global_scale=None in the config, so I just set global_rescale_scale: 1.1e-6, which is no longer below the threshold that triggers the ValueError. Does this seem OK?
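The check behind this ValueError can be sketched as a lower bound on the scale, with None meaning "disable global scaling". This is not nequip's actual code: the function name check_global_scale is hypothetical, and the threshold of 1e-6 is an assumption inferred only from the fact that 1.1e-6 passed while 0.0 did not.

```python
from typing import Optional

ASSUMED_THRESHOLD = 1e-6  # assumption: inferred from 1.1e-6 passing the check


def check_global_scale(
    scale: Optional[float], threshold: float = ASSUMED_THRESHOLD
) -> Optional[float]:
    """Hypothetical check: None disables global scaling; tiny values are rejected."""
    if scale is None:
        return None
    if scale < threshold:
        raise ValueError(f"Global energy scaling was very low: {scale}.")
    return scale


results = {}
for value in (0.0, 1.1e-6, None):
    try:
        results[value] = check_global_scale(value)
    except ValueError as err:
        results[value] = err
    print(value, "->", results[value])
```

Under this assumption, 0.0 is rejected, 1.1e-6 is accepted, and None skips the check entirely.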

pablo-unzueta • Jun 28 '23 18:06

Yes, in principle, if you just set it to some number, it will get overridden by the loaded state dict, but I'm still not totally sure why this is happening at all.
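The claim that a placeholder value gets overridden can be illustrated in plain PyTorch: as long as the shapes match, load_state_dict copies the checkpoint value over whatever the module was initialized with. ScaleHolder is a hypothetical module; this sketch does not show how nequip actually builds its model from the config.

```python
import torch


class ScaleHolder(torch.nn.Module):
    """Hypothetical module whose scale_by buffer is set from a config value."""

    def __init__(self, scale: float):
        super().__init__()
        self.register_buffer("scale_by", torch.tensor([scale]))


trained = ScaleHolder(0.8)        # stands in for the checkpointed model
restarted = ScaleHolder(1.1e-6)   # placeholder value from the restart config
restarted.load_state_dict(trained.state_dict())
# The checkpoint value has replaced the placeholder
print(restarted.scale_by)
```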

If you do this, does it pass sanity checks? For example, are the starting validation and training losses the same as before when you restart with the same dataset?
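A minimal sketch of this sanity check in plain PyTorch, using a placeholder torch.nn.Linear model and random data rather than nequip's trainer: save the state, rebuild a fresh model, reload, and confirm that the loss on the same data is unchanged. In an actual restart, the equivalent comparison would be the initial training and validation losses reported before and after restarting.

```python
import io

import torch


def mse_on(model: torch.nn.Module, x: torch.Tensor, y: torch.Tensor) -> float:
    """Loss on a fixed batch, evaluated without tracking gradients."""
    with torch.no_grad():
        return torch.nn.functional.mse_loss(model(x), y).item()


torch.manual_seed(0)
x, y = torch.randn(32, 4), torch.randn(32, 1)

original = torch.nn.Linear(4, 1)
loss_before = mse_on(original, x, y)

# Round-trip the weights through a serialized checkpoint, as a restart would
buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)

restarted = torch.nn.Linear(4, 1)  # freshly initialized, different weights
restarted.load_state_dict(torch.load(buffer))
loss_after = mse_on(restarted, x, y)

print(loss_before, loss_after)
```

If the two losses differ after a same-dataset restart, some state was not restored as expected.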

Linux-cpp-lisp • Jun 28 '23 21:06