
Error when exporting recurrent model (torch.Size shape mismatch)

Open devstermarts opened this issue 1 year ago • 2 comments

Hey @caillonantoine, I'm running into the following error on msprior export:

streaming mode is set to True
Traceback (most recent call last):
  File "/content/miniconda/bin/msprior", line 8, in <module>
    sys.exit(main())
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/main_cli.py", line 28, in main
    app.run(module.main)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/content/miniconda/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/content/miniconda/lib/python3.9/site-packages/msprior_scripts/export.py", line 16, in main
    model = ScriptedPrior(
  File "/content/miniconda/lib/python3.9/site-packages/msprior/scripted.py", line 53, in __init__
    model.load_state_dict(ckpt, strict=False)
  File "/content/miniconda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Prior:
    size mismatch for decoder.net.0._state: copying a param with shape torch.Size([8, 64, 512]) from checkpoint, the shape in current model is torch.Size([8, 1, 512]).

The msprior version is 1.1.2. Training was done with --config recurrent.
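Note that strict=False does not help here: in PyTorch it only tolerates missing or unexpected keys, while a size mismatch on a shared key still raises. A plausible reading (an assumption, not confirmed by the traceback) is that decoder.net.0._state is a recurrent state buffer whose second dimension (64) follows the training batch size, whereas the streaming export allocates it for a batch of 1. The sketch below is a hypothetical diagnostic/workaround: find_shape_mismatches and drop_keys are illustrative helpers, not part of msprior, and whether it is safe to let the exported model keep its freshly initialized _state is likewise an assumption.

```python
from typing import Dict, Iterable, List, Sequence, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(
    ckpt_shapes: Dict[str, Sequence[int]],
    model_shapes: Dict[str, Sequence[int]],
) -> List[Tuple[str, Shape, Shape]]:
    """Return (key, checkpoint_shape, model_shape) for every key present in
    both mappings whose shapes differ. Keys that exist on only one side are
    skipped, since load_state_dict(strict=False) already tolerates those."""
    mismatches = []
    for key, ckpt_shape in ckpt_shapes.items():
        if key not in model_shapes:
            continue
        ckpt_shape = tuple(ckpt_shape)
        model_shape = tuple(model_shapes[key])
        if ckpt_shape != model_shape:
            mismatches.append((key, ckpt_shape, model_shape))
    return mismatches


def drop_keys(state: Dict[str, object], keys: Iterable[str]) -> Dict[str, object]:
    """Return a shallow copy of `state` without the given keys."""
    skip = set(keys)
    return {k: v for k, v in state.items() if k not in skip}


if __name__ == "__main__":
    # Shapes taken from the traceback; the weight entry is a hypothetical key.
    ckpt_shapes = {
        "decoder.net.0._state": (8, 64, 512),
        "decoder.net.0.weight": (512, 512),
    }
    model_shapes = {
        "decoder.net.0._state": (8, 1, 512),
        "decoder.net.0.weight": (512, 512),
    }
    print(find_shape_mismatches(ckpt_shapes, model_shapes))
```

With real tensors, the shape mappings could be built as {k: tuple(v.shape) for k, v in ckpt.items()} and {k: tuple(v.shape) for k, v in model.state_dict().items()}. The filtered dict from drop_keys(ckpt, [k for k, _, _ in mismatches]) would then be passed to the model.load_state_dict(ckpt, strict=False) call shown in the traceback.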

devstermarts, Jul 05 '23 16:07