ManiSkill
[Question/Bug] Size Mismatch for actor_mean Runtime Error
I faced this issue before and was able to fix it, but for some reason I'm facing the same error again and I'm unsure why.
I previously ran python ppo.py --no-capture-video --env-id Bimanual_Allegro_Cube and then evaluated the final checkpoint, which ran into this error (and I am now getting it again):
(ms_dev) creativenick@creativenick:~/Desktop/SimToReal/bimanual-sapien$ python ppo.py --evaluate --no-capture-video --checkpoint runs/Bimanual_Allegro_Cube__ppo__1__1717740309/final_ckpt.pt
/home/creativenick/anaconda3/envs/ms_dev/lib/python3.9/site-packages/tyro/_fields.py:343: UserWarning: The field wandb_entity is annotated with type <class 'str'>, but the default value None has type <class 'NoneType'>. We'll try to handle this gracefully, but it may cause unexpected behavior.
warnings.warn(
/home/creativenick/anaconda3/envs/ms_dev/lib/python3.9/site-packages/tyro/_fields.py:343: UserWarning: The field checkpoint is annotated with type <class 'str'>, but the default value None has type <class 'NoneType'>. We'll try to handle this gracefully, but it may cause unexpected behavior.
warnings.warn(
Running evaluation
/home/creativenick/anaconda3/envs/ms_dev/lib/python3.9/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
logger.warn(
/home/creativenick/anaconda3/envs/ms_dev/lib/python3.9/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
logger.warn(
/home/creativenick/anaconda3/envs/ms_dev/lib/python3.9/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.max_episode_steps to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.max_episode_steps` for environment variables or `env.get_wrapper_attr('max_episode_steps')` that will search the reminding wrappers.
logger.warn(
####
args.num_iterations=97 args.num_envs=512 args.num_eval_envs=2
args.minibatch_size=3200 args.batch_size=102400 args.update_epochs=4
####
Traceback (most recent call last):
File "/home/creativenick/Desktop/SimToReal/bimanual-sapien/ppo.py", line 318, in <module>
agent.load_state_dict(torch.load(args.checkpoint))
File "/home/creativenick/anaconda3/envs/ms_dev/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Agent:
size mismatch for actor_logstd: copying a param with shape torch.Size([1, 44]) from checkpoint, the shape in current model is torch.Size([1, 8]).
size mismatch for critic.0.weight: copying a param with shape torch.Size([256, 98]) from checkpoint, the shape in current model is torch.Size([256, 42]).
size mismatch for actor_mean.0.weight: copying a param with shape torch.Size([256, 98]) from checkpoint, the shape in current model is torch.Size([256, 42]).
size mismatch for actor_mean.6.weight: copying a param with shape torch.Size([44, 256]) from checkpoint, the shape in current model is torch.Size([8, 256]).
size mismatch for actor_mean.6.bias: copying a param with shape torch.Size([44]) from checkpoint, the shape in current model is torch.Size([8]).
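From the shapes in the traceback, the checkpoint was trained against an environment with larger observation/action spaces than the one being built for evaluation (44 vs. 8 action dims in actor_logstd and actor_mean.6.*, 98 vs. 42 input features in critic.0.weight and actor_mean.0.weight). A quick way to confirm what the checkpoint expects is to dump its parameter shapes; this is a minimal sketch, assuming final_ckpt.pt holds a plain state_dict saved with torch.save (which is what agent.load_state_dict(torch.load(args.checkpoint)) in ppo.py implies):

```python
import torch

# Load the checkpoint on CPU just to inspect it
ckpt = torch.load(
    "runs/Bimanual_Allegro_Cube__ppo__1__1717740309/final_ckpt.pt",
    map_location="cpu",
)

# Print every stored parameter shape; actor_logstd / actor_mean.6.* encode the action
# dimension, while critic.0.weight / actor_mean.0.weight encode the observation dimension.
for name, tensor in ckpt.items():
    print(f"{name}: {tuple(tensor.shape)}")
```

Comparing these shapes against envs.unwrapped.single_observation_space and envs.unwrapped.single_action_space in the evaluation run should show which side changed between training and evaluation.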
I previously fixed the error by:
- Making sure the max episode length in the defined environment is larger than num_steps in the ppo.py script
- Changing num_envs=args.num_envs if not args.evaluate else 1, to num_envs=args.num_envs, (see the sketch after this list)
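For reference, the second change looked roughly like this. This is a hedged sketch of the relevant line, not the exact call from my ppo.py; the real gym.make call passes additional environment kwargs that are omitted here:

```python
import gymnasium as gym

# Before: evaluation forced a single environment
# envs = gym.make(args.env_id, num_envs=args.num_envs if not args.evaluate else 1)

# After: evaluation uses the same num_envs as training
envs = gym.make(args.env_id, num_envs=args.num_envs)
```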
That fixed the issue at the time, but I'm now facing the same error again. Here are the links to my cube_env.py environment file and my ppo.py file.