stylegan2-pytorch

Unable to load from checkpoint after migrating model to different machine

Open moinedgylabs opened this issue 3 years ago • 6 comments

Hi,

Training was running on a machine with a small GPU, so I saved the latest model at the 450th epoch (model_450.pt) and moved it to another machine. I placed the saved model at models/default/model_450.pt on the new machine. I have the same version of stylegan2-pytorch (1.8.1) on both machines. Now when I run the same command, it gives me the following error:

continuing from previous epoch - 450
loading from version 1.8.1
unable to load save model. please try downgrading the package to the version specified by the saved model
Traceback (most recent call last):
  File "/opt/conda/bin/stylegan2_pytorch", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.7/site-packages/stylegan2_pytorch/cli.py", line 187, in main
    fire.Fire(train_from_folder)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 471, in _Fire
    target=component.__name__)
  File "/opt/conda/lib/python3.7/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/stylegan2_pytorch/cli.py", line 178, in train_from_folder
    run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed)
  File "/opt/conda/lib/python3.7/site-packages/stylegan2_pytorch/cli.py", line 52, in run_training
    model.load(load_from)
  File "/opt/conda/lib/python3.7/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 1394, in load
    raise e
  File "/opt/conda/lib/python3.7/site-packages/stylegan2_pytorch/stylegan2_pytorch.py", line 1391, in load
    self.GAN.load_state_dict(load_data['GAN'])
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for StyleGAN2:
        size mismatch for G.blocks.3.to_noise1.weight: copying a param with shape torch.Size([256, 1]) from checkpoint, the shape in current model is torch.Size([512, 1]).
        size mismatch for G.blocks.3.to_noise1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
        size mismatch for G.blocks.3.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).

I would appreciate any help. I really do not want to train from scratch.

moinedgylabs avatar May 12 '21 17:05 moinedgylabs
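For anyone reading the traceback above: every reported failure is a shape mismatch in G.blocks.3, where the checkpoint has 256 output channels and the freshly built model has 512. That pattern is more consistent with the model being constructed from different settings than with a corrupted file, although the thread does not confirm this. The sketch below is only an illustration for pulling the shapes out of a pasted error message. The function names and the regular expression are hypothetical helpers, not part of stylegan2-pytorch.

```python
import re

# Matches one "size mismatch" line from the load_state_dict error message.
MISMATCH = re.compile(
    r"size mismatch for (?P<name>\S+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]*)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]*)\]\)"
)


def parse_shape(text):
    """Turn '256, 512, 3, 3' into (256, 512, 3, 3)."""
    return tuple(int(n) for n in text.split(",") if n.strip())


def parse_mismatches(error_text):
    """Return (parameter name, checkpoint shape, current model shape) triples."""
    return [
        (m["name"], parse_shape(m["ckpt"]), parse_shape(m["model"]))
        for m in MISMATCH.finditer(error_text)
    ]


# Two lines copied from the error message in this issue.
error_text = """
size mismatch for G.blocks.3.to_noise1.weight: copying a param with shape torch.Size([256, 1]) from checkpoint, the shape in current model is torch.Size([512, 1]).
size mismatch for G.blocks.3.conv1.weight: copying a param with shape torch.Size([256, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([512, 512, 3, 3]).
"""

mismatches = parse_mismatches(error_text)
for name, ckpt, model in mismatches:
    print(f"{name}: checkpoint {ckpt} vs current model {model}")
```

Seeing the same factor of two on every mismatched parameter suggests checking the settings that determine layer widths (for example image_size and network_capacity in the config) before assuming the checkpoint is damaged.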

@moinedgylabs Hi Moin, you'll have to also bring over the config file (that exists alongside the model), with the suffix .config.json

lucidrains avatar May 13 '21 15:05 lucidrains

@lucidrains thanks for the reply. Unfortunately, I do not see any .config.json file in the models/default directory. All I see are .pt files. Do I need to look somewhere else for it? Any way I can regenerate this file on the new machine?

moinedgylabs avatar May 13 '21 16:05 moinedgylabs
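One possible reason the file does not show up: if the config is saved as a dotfile (a name that begins with a period, such as .config.json), a plain directory listing hides it by default. That the library names it this way is an assumption based on the ".config.json" suffix mentioned above. The following minimal sketch lists hidden entries in a directory. The helper names are hypothetical and the demo uses a temporary directory in place of models/default.

```python
import tempfile
from pathlib import Path


def list_all_entries(directory):
    """Return every entry name in `directory`, dotfiles included, sorted."""
    return sorted(p.name for p in Path(directory).iterdir())


def hidden_entries(directory):
    """Return only the entries whose names start with a period."""
    return [name for name in list_all_entries(directory) if name.startswith(".")]


# Demo on a throwaway directory that stands in for models/default.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "model_450.pt").write_bytes(b"")   # placeholder checkpoint
    (Path(tmp) / ".config.json").write_text("{}")   # placeholder hidden config
    found = hidden_entries(tmp)

print(found)
```

On the original machine, pointing hidden_entries at the real models/default directory (or running ls -a there) would show whether such a file exists.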

{"image_size": 128, "network_capacity": 16, "lr_mlp": 0.1, "transparent": false, "fq_layers": [], "fq_dict_size": 256, "attn_layers": [], "no_const": false}

This is the file in my case. I think you should just modify it a little bit.

MartinKing01 avatar May 13 '21 16:05 MartinKing01
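A minimal sketch of recreating such a file by hand, assuming the config lives next to the checkpoints as .config.json (the exact filename and location are assumptions taken from the comments above, not verified against the library). The values are copied from the config posted above and would need to be changed to whatever the checkpoint was actually trained with. write_config is a hypothetical helper, and the demo writes into a temporary directory, not the real models/default.

```python
import json
import tempfile
from pathlib import Path

# Values copied from the config posted in this thread; adjust them to match
# the settings used when the checkpoint was trained.
config = {
    "image_size": 128,
    "network_capacity": 16,
    "lr_mlp": 0.1,
    "transparent": False,
    "fq_layers": [],
    "fq_dict_size": 256,
    "attn_layers": [],
    "no_const": False,
}


def write_config(model_dir, config, filename=".config.json"):
    """Write `config` as JSON into `model_dir` and return the file path."""
    path = Path(model_dir) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(config))
    return path


# Demo: write into a temporary stand-in for models/default and read it back.
with tempfile.TemporaryDirectory() as tmp:
    path = write_config(Path(tmp) / "models" / "default", config)
    restored = json.loads(path.read_text())

print(restored["image_size"], restored["network_capacity"])
```

If the values in the file do not match the ones used for training, loading would be expected to fail with the same kind of size mismatch shown in the original error.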

Thanks, @MartinKing01. I'll give it a try.

moinedgylabs avatar May 13 '21 16:05 moinedgylabs

Unfortunately, I still get the same error, both with the above config and with a config generated from a fresh install. Maybe something in the model is missing, or the file got corrupted.

moinedgylabs avatar May 14 '21 11:05 moinedgylabs

I get the same error, but unlike the OP I didn't move to a different machine. In my case it is a version issue (I trained on 1.8.1, and now I can't reload the model even with a downgraded version of the code). Did you see this issue? https://github.com/lucidrains/stylegan2-pytorch/issues/237

RobertRankinTR avatar Jun 21 '21 16:06 RobertRankinTR