
Cannot find key 'n_models'

Open. LukasMosser opened this issue 4 years ago • 1 comment

Hi @wjmaddox!

I've been trying to reproduce the results for the segmentation experiment and have hit an error I can't seem to fix. I'm using the commands in the README to train a SWAG model and then evaluate it, but I end up with the following error. Any idea what the cause could be?

python eval_ensemble.py --data_path /home/ec2-user/CamVid/ --batch_size 4 --method SWAG --scale=0.5 --loss cross_entropy --N 50 --file ./experiment_swag/checkpoint-1000.pt --save_path ./experiment_swag/output.npz

/home/ec2-user/CamVid/
Preparing model
Loading model ./experiment_swag/checkpoint-1000.pt

Traceback (most recent call last):
  File "eval_ensemble.py", line 146, in <module>
    model.load_state_dict(checkpoint["state_dict"])
  File "/home/ec2-user/swa_gaussian/swag/posteriors/swag.py", line 182, in load_state_dict
    n_models = state_dict["n_models"].item()

KeyError: 'n_models'

LukasMosser, Oct 16 '20 07:10

Hi,

Just a word of caution here: we never could really get the segmentation code to reproduce the results in the original Tiramisu paper, so I don't know what you'll find there :(

That being said, it looks like the issue is that you're trying to load a model that does not have an "n_models" buffer in its state dict. How did you train it?
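One quick way to check whether a checkpoint really lacks the "n_models" entry is to list the keys of its state dict before loading it. The helper below is only a sketch: `summarize_state_dict_keys` is a hypothetical name, and the key names in the toy dictionaries are made up for illustration and are not taken from the repository.

```python
from typing import Iterable, Mapping


def summarize_state_dict_keys(
    state_dict: Mapping[str, object],
    required: Iterable[str] = ("n_models",),
) -> dict:
    """Report which required keys are absent and show a sample of the keys present."""
    keys = list(state_dict.keys())
    return {
        "missing": [k for k in required if k not in state_dict],
        "num_keys": len(keys),
        "sample": keys[:5],
    }


# With a real checkpoint this would be used roughly as follows (needs torch):
#   import torch
#   checkpoint = torch.load("./experiment_swag/checkpoint-1000.pt", map_location="cpu")
#   print(summarize_state_dict_keys(checkpoint["state_dict"]))

# Toy dictionaries with hypothetical key names, standing in for real state dicts.
plain_like = {"base.conv.weight": 0, "base.conv.bias": 0}
swag_like = {"n_models": 20, "base.conv.weight_mean": 0}

plain_report = summarize_state_dict_keys(plain_like)
swag_report = summarize_state_dict_keys(swag_like)
print(plain_report["missing"])
print(swag_report["missing"])
```

If "n_models" shows up as missing, the checkpoint was most likely saved from something other than the SWAG wrapper, which is what the training question above is getting at.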

If you're confident that you did train a SWAG model and are reloading that checkpoint, make sure that the n_models buffer in the script is set to the value you trained the model with, and pass strict=False, as in: model.load_state_dict(checkpoint["state_dict"], strict=False)
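For reference, here is a minimal sketch of what strict=False does in stock PyTorch. `ToyPosterior` is a hypothetical stand-in, not the repository's SWAG class, and the value 20 is an arbitrary placeholder. The traceback in the report shows the key being read inside the repository's own load_state_dict (swag.py, line 182), so this toy only illustrates the default `nn.Module` behaviour and does not show whether the flag alone avoids that KeyError.

```python
import torch
import torch.nn as nn


class ToyPosterior(nn.Module):
    """Hypothetical module with an n_models buffer (not the real SWAG class)."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        # Buffers are saved in the state dict alongside parameters.
        self.register_buffer("n_models", torch.zeros(1, dtype=torch.long))


# Build a state dict that has the weights but no "n_models" entry.
source = nn.Linear(2, 1)
state = {"linear." + k: v for k, v in source.state_dict().items()}

model = ToyPosterior()

# Default strict=True: a missing key raises RuntimeError.
strict_error = None
try:
    model.load_state_dict(state)
except RuntimeError as err:
    strict_error = str(err)

# strict=False: the load succeeds and the missing keys are reported instead.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)

# The skipped buffer keeps its initial value, so set it by hand.
model.n_models.fill_(20)  # 20 is a placeholder, not a value from the issue
```

The design point is that strict=False silently leaves any missing entry at its initial value, so it is worth printing `missing_keys` and setting the buffer explicitly rather than assuming the load restored everything.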

wjmaddox, Oct 20 '20 13:10