msprior

Error on export recurrent model (torch.Size shape mismatch)

Open devstermarts opened this issue 2 years ago • 2 comments

Hey @caillonantoine, I'm running into the following error on msprior export:

streaming mode is set to True
Traceback (most recent call last):
  File "/content/miniconda/bin/msprior", line 8, in <module>
    sys.exit(main())
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/main_cli.py", line 28, in main
    app.run(module.main)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/export.py", line 16, in main
    model = ScriptedPrior(
  File "/content/miniconda/lib/python3.9/site-packages/msprior/scripted.py", line 53, in __init__
    model.load_state_dict(ckpt, strict=False)
  File "/content/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Prior:
    size mismatch for decoder.net.0._state: copying a param with shape torch.Size([8, 64, 512]) from checkpoint, the shape in current model is torch.Size([8, 1, 512]).

The msprior version is 1.1.2. Training was done with --config recurrent.

devstermarts, Jul 05 '23 16:07

Adding some speculation here until someone else replies: I think the problem comes from the training and export scripts using different batch sizes. The checkpoint stores decoder.net.0._state with shape [8, 64, 512], while the model built at export time expects [8, 1, 512], so the second dimension looks like the batch size. Setting --batch_size 64 on the export script (64 is the default batch size during training) got rid of the error for me, although I'm not sure this is the right way to export the model. I'm also unsure why this error didn't show up with the decoder_only configuration.
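To check which dimension actually differs, here is a minimal sketch that parses the size-mismatch line from the traceback above. It uses only the Python standard library; the helper names (parse_shape, differing_axes) are made up for illustration and are not part of msprior or PyTorch, and the regex assumes the message format shown in this issue.

```python
import re

# Error text copied from the traceback in this issue.
ERROR = (
    "size mismatch for decoder.net.0._state: copying a param with shape "
    "torch.Size([8, 64, 512]) from checkpoint, the shape in current model is "
    "torch.Size([8, 1, 512])."
)

# Assumes the wording of the size-mismatch message seen above.
PATTERN = re.compile(
    r"size mismatch for (?P<name>\S+): copying a param with shape "
    r"torch\.Size\(\[(?P<ckpt>[\d, ]*)\]\) from checkpoint, "
    r"the shape in current model is torch\.Size\(\[(?P<model>[\d, ]*)\]\)"
)


def parse_shape(text):
    """Turn '8, 64, 512' into (8, 64, 512)."""
    return tuple(int(part) for part in text.split(",") if part.strip())


def differing_axes(message):
    """Return (name, checkpoint_shape, model_shape, axes_that_differ)."""
    match = PATTERN.search(message)
    if match is None:
        raise ValueError("no size-mismatch entry found in message")
    ckpt = parse_shape(match["ckpt"])
    model = parse_shape(match["model"])
    # Compare the two shapes position by position.
    axes = [i for i, (a, b) in enumerate(zip(ckpt, model)) if a != b]
    return match["name"], ckpt, model, axes


name, ckpt, model, axes = differing_axes(ERROR)
for axis in axes:
    print(f"{name}: axis {axis} is {ckpt[axis]} in the checkpoint, {model[axis]} in the model")
```

Only axis 1 differs (64 in the checkpoint, 1 in the export-time model), which is consistent with the batch-size guess. That axis 1 really is the batch dimension of the recurrent state is still an assumption on my part.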

snnithya, Jul 05 '23 17:07

Thanks @snnithya for looking into this. Using --batch_size 64 on msprior export did the trick for now.

devstermarts, Jul 06 '23 11:07