
Error(s) in loading state_dict for CMAPolicy:

Open volkancirik opened this issue 3 years ago • 3 comments

Hello,

I downloaded the pretrained model and tried to run test_set_inference.yaml, and I receive this error:

Unexpected key(s) in state_dict: "critic.fc.weight", "critic.fc.bias"

Is there a mismatch between the code and the pretrained model? If so, could you update the model link? Or do you have other suggestions? Thanks in advance!

volkancirik, Jun 23 '21 17:06

The pretrained models load correctly for me on the master branch. For the rxr-habitat-challenge branch, the baseline models inherit from a base policy class that does not have critic.* layers. These layers were never used during training/evaluation of the baselines and can be safely ignored:

https://github.com/jacobkrantz/VLN-CE/blob/52a478259ae41068eecb751a0cc49307384e460b/vlnce_baselines/common/base_il_trainer.py#L76

The load_state_dict call at that line can be changed to

self.policy.load_state_dict(ckpt_dict["state_dict"], strict=False)
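Setting strict=False silences every key mismatch, not just the unused critic.* layers. A stricter alternative is to drop only those entries and keep strict loading on, so any other mismatch still raises. A minimal sketch, assuming the checkpoint is a dict with a "state_dict" entry as in the line above; the helper name drop_prefixed_keys is hypothetical and not part of VLN-CE:

```python
from collections import OrderedDict
from typing import Any, Mapping


def drop_prefixed_keys(
    state_dict: Mapping[str, Any], prefix: str = "critic."
) -> "OrderedDict[str, Any]":
    """Return a copy of state_dict without the entries whose key starts with prefix."""
    return OrderedDict(
        (key, value) for key, value in state_dict.items() if not key.startswith(prefix)
    )


# Hypothetical use inside the trainer, in place of strict=False:
#   filtered = drop_prefixed_keys(ckpt_dict["state_dict"])
#   self.policy.load_state_dict(filtered)  # strict=True still catches other mismatches
```

If you do keep strict=False, note that load_state_dict returns a named tuple with missing_keys and unexpected_keys fields. Printing it is a quick way to confirm that only critic.fc.weight and critic.fc.bias were skipped.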

jacobkrantz, Jun 28 '21 19:06

Hi @jacobkrantz! Thanks for open-sourcing your code. I followed the instructions in README.md but still get the same error that @volkancirik described above.

@volkancirik have you figured out how to fix it?

UPD: I set self.policy.load_state_dict(ckpt_dict["state_dict"], strict=False) and then ran into RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0. To fix it, I followed the advice in https://github.com/pytorch/pytorch/issues/43227 and moved lengths to the CPU with lengths.cpu() (in the torch/nn/utils/rnn.py file).
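Editing torch/nn/utils/rnn.py changes the installed library for every project that uses it. A less invasive option is to move lengths to the CPU at the call site, before it reaches pack_padded_sequence; the padded inputs themselves can stay on the GPU. A minimal sketch, assuming the model packs its RNN inputs with torch.nn.utils.rnn.pack_padded_sequence; the wrapper name pack_with_cpu_lengths is hypothetical and not part of VLN-CE or PyTorch:

```python
import torch
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence


def pack_with_cpu_lengths(
    inputs: torch.Tensor,
    lengths: torch.Tensor,
    batch_first: bool = False,
    enforce_sorted: bool = True,
) -> PackedSequence:
    """Pack a padded batch after converting lengths to a 1D CPU int64 tensor.

    Keyword defaults mirror pack_padded_sequence; only lengths is moved.
    """
    cpu_lengths = lengths.detach().cpu().to(torch.int64)
    return pack_padded_sequence(
        inputs, cpu_lengths, batch_first=batch_first, enforce_sorted=enforce_sorted
    )


# Usage: a padded batch of 3 sequences (lengths 5, 2, 3) with 4 features each.
device = "cuda" if torch.cuda.is_available() else "cpu"
padded = torch.zeros(3, 5, 4, device=device)
lengths = torch.tensor([5, 2, 3], device=device)
packed = pack_with_cpu_lengths(padded, lengths, batch_first=True, enforce_sorted=False)
print(packed.data.shape)  # torch.Size([10, 4])
```

Because the conversion happens in the caller, the fix survives reinstalling or upgrading PyTorch.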

After that, I was finally able to evaluate the CMA_PM_DA_Aug.pth checkpoint without errors by running

python run.py --exp-config vlnce_baselines/config/paper_configs/cma_pm_da_aug_tune.yaml --run-type eval

and got the following results:

Episodes evaluated: 1839
Average episode distance_to_goal: 7.610544
Average episode success: 0.287113
Average episode spl: 0.265512
Average episode ndtw: 0.496594
Average episode path_length: 8.303312
Average episode oracle_success: 0.356172
Average episode steps_taken: 88.323545

rpartsey, Sep 24 '21 08:09

Do you have any idea why the performance is much lower than the results listed here?

cshizhe, Dec 09 '21 08:12