imitation
Observation shape mismatch
Bug description
I am trying to roll out trajectories with PPO/A2C on the task "BreakoutNoFrameskip-v4". It fails with "ValueError: Observation spaces do not match". I think the observation got transposed: (4,84,84) != (84,84,4).
Steps to reproduce
from stable_baselines3 import PPO,A2C
from stable_baselines3.ppo import MlpPolicy,CnnPolicy
from stable_baselines3.common.evaluation import evaluate_policy
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
import argparse
import gym
import torch
import numpy as np
if __name__ == "__main__":
    seed = 0
    env = make_atari_env("BeamRiderNoFrameskip-v4", n_envs=1, seed=seed)
    env = VecFrameStack(env, n_stack=4)
    expert = A2C(policy=CnnPolicy, env=env, seed=seed)
    expert.learn(25000)
    expert.save('./expert')
    expert.load('./expert')
    reward, _ = evaluate_policy(expert, env, 10)
    print("Expert Performance", reward)
    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=2),
    )
    transitions = rollout.flatten_trajectories(rollouts)
Environment
- Operating system and version: RHEL7.9
- Python version: Python 3.9.1
- Output of pip freeze --all:
$ pip3.9 freeze --all
absl-py==1.2.0
ale-py==0.7.4
astor==0.8.1
astunparse==1.6.3
AutoROM==0.4.2
AutoROM.accept-rom-license==0.4.2
# Editable install with no version control (baselines==0.1.5)
-e /data/brzheng/Project/Explainable IL/baselines
bleach==1.5.0
cachetools==5.2.0
certifi==2022.6.15
cffi==1.15.1
chai-sacred==0.8.3
charset-normalizer==2.1.1
click==8.1.3
cloudpickle==1.2.2
colorama==0.4.5
cycler==0.11.0
Cython==0.29.32
dill==0.3.5.1
docopt==0.6.2
flatbuffers==1.12
fonttools==4.37.1
future==0.18.2
gast==0.4.0
gitdb==4.0.9
GitPython==3.1.27
glfw==2.5.4
google-auth==2.11.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.34.1
gym==0.21.0
gym-notices==0.0.8
h5py==3.1.0
html5lib==0.9999999
idna==3.3
imageio==2.21.2
imitation==0.3.1
importlib-metadata==4.12.0
importlib-resources==5.9.0
joblib==1.1.0
jsonpickle==2.2.0
keras==2.9.0
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==14.0.6
lockfile==0.12.2
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.3
mpi4py==3.1.3
mujoco-py==1.50.1.68
munch==2.5.0
numpy==1.23.2
oauthlib==3.2.0
opencv-python==4.6.0.66
opt-einsum==3.3.0
packaging==21.3
pandas==1.4.3
patchelf==0.15.0.0
Pillow==9.2.0
pip==22.2.2
progressbar2==4.0.0
protobuf==3.19.4
py-cpuinfo==8.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pyglet==1.3.2
pyparsing==3.0.9
python-dateutil==2.8.2
python-utils==3.3.3
pytz==2022.2.1
pyzmq==23.2.1
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
scikit-learn==1.1.2
scipy==1.9.1
seals==0.1.2
setuptools==49.2.1
six==1.15.0
smmap==5.0.0
stable-baselines3 @ git+https://github.com/DLR-RM/stable-baselines3@a7f30b04e3285b62ed72ed3a7183972c03358681
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.1
tensorflow-estimator==2.9.0
tensorflow-io-gcs-filesystem==0.26.0
tensorflow-tensorboard==1.5.1
termcolor==1.1.0
threadpoolctl==3.1.0
torch==1.11.0
tqdm==4.64.0
typing-extensions==3.7.4.3
urllib3==1.26.12
Werkzeug==2.2.2
wheel==0.37.1
wrapt==1.12.1
zipp==3.8.1
zmq==0.0.0
Indeed: Stable Baselines3 internally transposes the env to be channels-first for training, but rollout.rollout is using the env that you passed in. The solution is to instead call:
rollouts = rollout.rollout(
    expert,
    expert.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=2),
)
Then you get a different error that I don't really understand: unwrap_traj ends up looking up the 'rollout' key on some object that doesn't have that key.
At any rate, we should probably document this issue of rolling out on the wrong env somewhere.
Related: https://github.com/HumanCompatibleAI/imitation/issues/486, https://github.com/HumanCompatibleAI/imitation/pull/519
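To make the mismatch concrete, here is a minimal sketch that uses plain shape tuples only, with no SB3 or gym calls. The helper names `to_channels_first` and `explain_mismatch` are hypothetical and exist only for this illustration; they are not part of imitation or SB3. As far as I know, SB3 performs the real transposition by wrapping image envs in VecTransposeImage, which is why expert.get_env() and the env you passed in can disagree.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def to_channels_first(shape: Shape) -> Shape:
    """Reorder an (H, W, C) image shape to (C, H, W). Hypothetical helper."""
    if len(shape) != 3:
        raise ValueError(f"expected a 3-D image shape, got {shape}")
    height, width, channels = shape
    return (channels, height, width)


def explain_mismatch(policy_shape: Shape, env_shape: Shape) -> str:
    """Describe how the policy's and the env's observation shapes relate."""
    if policy_shape == env_shape:
        return "shapes match"
    if len(env_shape) == 3 and to_channels_first(env_shape) == policy_shape:
        return "env is channels-last, policy expects channels-first"
    return "shapes are unrelated"


# Shapes taken from the error message in this issue.
env_shape = (84, 84, 4)     # env as passed to rollout.rollout
policy_shape = (4, 84, 84)  # what the trained policy was fed
print(explain_mismatch(policy_shape, env_shape))
```

A check like this before collecting rollouts would turn the ValueError into a more direct hint that the wrong env is being used.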
Thanks for your comments. Is there anything I could try to overcome this problem? What if I set unwrap to False to bypass the unwrap_traj function? By the way, I simply changed the expert to PPO, but the agent does not seem to learn from the environment. What could be the problem?
Then you get a different error that I don't really understand - unwrap_traj ends up calling the 'rollout' key of some object that doesn't have that key.
https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/data/wrappers.py#L160 adds the "rollout" key. I do not know why that would be getting lost in algorithm.get_env(). I guess one of the wrappers SB3 applies might eliminate the info dict, though that seems counterintuitive.
Setting unwrap to False is probably a viable workaround. It does mean you'll lose the last observation (VecEnv discards it as it auto-resets), but that might be OK for your use case.
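The auto-reset behaviour behind the lost last observation can be sketched with a toy environment. This is an illustration under assumptions, not SB3's implementation: `CountingEnv` and `auto_reset_step` are made-up names, and the "terminal_observation" info key only mirrors the convention that SB3's vectorized envs use, to my knowledge, to stash the final observation.

```python
from typing import Any, Dict, Tuple


class CountingEnv:
    """Toy env: the observation is the step count; episodes last `horizon` steps."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: Any) -> Tuple[int, float, bool, Dict[str, Any]]:
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done, {}


def auto_reset_step(env: CountingEnv, action: Any) -> Tuple[int, float, bool, Dict[str, Any]]:
    """Step the env and reset immediately when the episode ends."""
    obs, reward, done, info = env.step(action)
    if done:
        # The true final observation survives only in the info dict;
        # the returned observation is already the first one of the next episode.
        info["terminal_observation"] = obs
        obs = env.reset()
    return obs, reward, done, info


env = CountingEnv(horizon=3)
env.reset()
seen = []
for _step in range(3):
    obs, reward, done, info = auto_reset_step(env, action=None)
    seen.append(obs)

print(seen)                          # last entry is the reset observation
print(info["terminal_observation"])  # the observation that ended the episode
```

Code that reads only the returned observations never sees the terminal one, which is the loss described above when unwrap is False.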
Yeah, now it can roll out trajectories! But the agent could not learn a meaningful policy using the default BC parameters.
from imitation.algorithms import bc

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
)
bc_trainer.train(n_epochs=300)
Is there anything I could try to improve it?
As a sanity check: are the transitions you're training BC on transposed correctly relative to the environment you're giving BC? I imagine it would error out if not, but if it were treating the channel dimension as the batch dimension, that would certainly prevent it from learning...
Otherwise, I don't have much to advise other than "tune hyperparameters". Behavioral cloning is a pretty weak imitation learning algorithm anyway, so it's common for it to require methods like data augmentation. Something like GAIL tends to be a lot more robust.
Closing due to inactivity.