
Observation shape mismatch

Open BryanZ666 opened this issue 2 years ago • 6 comments

Bug description

I tried to roll out trajectories with PPO/A2C for the task "BreakoutNoFrameskip-v4". An error occurred: "ValueError: Observation spaces do not match". I think the observation got transposed: (4, 84, 84) != (84, 84, 4).

Steps to reproduce

from stable_baselines3 import PPO, A2C
from stable_baselines3.ppo import MlpPolicy, CnnPolicy
from stable_baselines3.common.evaluation import evaluate_policy
from imitation.data import rollout
from imitation.data.wrappers import RolloutInfoWrapper
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

import argparse

import gym
import torch
import numpy as np

if __name__ == "__main__":
    seed = 0
    env = make_atari_env("BeamRiderNoFrameskip-v4", n_envs=1, seed=seed)
    env = VecFrameStack(env, n_stack=4)

    expert = A2C(policy=CnnPolicy, env=env, seed=seed)
    expert.learn(25000)
    expert.save('./expert')

    # load() is a classmethod that returns a new model, so assign its result
    expert = A2C.load('./expert', env=env)

    reward, _ = evaluate_policy(expert, env, 10)
    print("Expert performance:", reward)

    rollouts = rollout.rollout(
        expert,
        env,
        rollout.make_sample_until(min_timesteps=None, min_episodes=2),
    )
    transitions = rollout.flatten_trajectories(rollouts)

Environment

  • Operating system and version: RHEL7.9
  • Python version: Python 3.9.1
  • Output of pip freeze --all:

$ pip3.9 freeze --all absl-py==1.2.0 ale-py==0.7.4 astor==0.8.1 astunparse==1.6.3 AutoROM==0.4.2 AutoROM.accept-rom-license==0.4.2 # Editable install with no version control (baselines==0.1.5) -e /data/brzheng/Project/Explainable IL/baselines bleach==1.5.0 cachetools==5.2.0 certifi==2022.6.15 cffi==1.15.1 chai-sacred==0.8.3 charset-normalizer==2.1.1 click==8.1.3 cloudpickle==1.2.2 colorama==0.4.5 cycler==0.11.0 Cython==0.29.32 dill==0.3.5.1 docopt==0.6.2 flatbuffers==1.12 fonttools==4.37.1 future==0.18.2 gast==0.4.0 gitdb==4.0.9 GitPython==3.1.27 glfw==2.5.4 google-auth==2.11.0 google-auth-oauthlib==0.4.6 google-pasta==0.2.0 grpcio==1.34.1 gym==0.21.0 gym-notices==0.0.8 h5py==3.1.0 html5lib==0.9999999 idna==3.3 imageio==2.21.2 imitation==0.3.1 importlib-metadata==4.12.0 importlib-resources==5.9.0 joblib==1.1.0 jsonpickle==2.2.0 keras==2.9.0 keras-nightly==2.5.0.dev2021032900 Keras-Preprocessing==1.1.2 kiwisolver==1.4.4 libclang==14.0.6 lockfile==0.12.2 Markdown==3.4.1 MarkupSafe==2.1.1 matplotlib==3.5.3 mpi4py==3.1.3 mujoco-py==1.50.1.68 munch==2.5.0 numpy==1.23.2 oauthlib==3.2.0 opencv-python==4.6.0.66 opt-einsum==3.3.0 packaging==21.3 pandas==1.4.3 patchelf==0.15.0.0 Pillow==9.2.0 pip==22.2.2 progressbar2==4.0.0 protobuf==3.19.4 py-cpuinfo==8.0.0 pyasn1==0.4.8 pyasn1-modules==0.2.8 pycparser==2.21 pyglet==1.3.2 pyparsing==3.0.9 python-dateutil==2.8.2 python-utils==3.3.3 pytz==2022.2.1 pyzmq==23.2.1 requests==2.28.1 requests-oauthlib==1.3.1 rsa==4.9 scikit-learn==1.1.2 scipy==1.9.1 seals==0.1.2 setuptools==49.2.1 six==1.15.0 smmap==5.0.0 stable-baselines3 @ git+https://github.com/DLR-RM/stable-baselines3@a7f30b04e3285b62ed72ed3a7183972c03358681 tensorboard==2.9.1 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorflow==2.9.1 tensorflow-estimator==2.9.0 tensorflow-io-gcs-filesystem==0.26.0 tensorflow-tensorboard==1.5.1 termcolor==1.1.0 threadpoolctl==3.1.0 torch==1.11.0 tqdm==4.64.0 typing-extensions==3.7.4.3 urllib3==1.26.12 Werkzeug==2.2.2 wheel==0.37.1 wrapt==1.12.1 zipp==3.8.1 zmq==0.0.0

BryanZ666 avatar Sep 01 '22 09:09 BryanZ666

Indeed - Stable Baselines3 internally wraps the env to transpose image observations to channels-first for training, but rollout.rollout uses the env that you passed in. The solution is to instead call:

rollouts = rollout.rollout(
    expert,
    expert.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=2),
)

Then you get a different error that I don't really understand - unwrap_traj ends up looking up the 'rollout' key of some object that doesn't have that key.

At any rate, we should probably document this issue of rolling out on the wrong env somewhere.
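For reference, a rough sketch (not tested here) of the alternative: keep your own env but transpose it with SB3's VecTransposeImage so the env you hand to rollout.rollout and the env the policy was trained on share the same channels-first observation space. Names follow the reproduction script above; rollout_env is just a new local variable:

from stable_baselines3.common.vec_env import VecTransposeImage

# Transpose image observations to match what SB3 trained the policy on
rollout_env = VecTransposeImage(env)  # (84, 84, 4) -> (4, 84, 84)
rollouts = rollout.rollout(
    expert,
    rollout_env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=2),
)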

dfilan avatar Sep 01 '22 17:09 dfilan

Related: https://github.com/HumanCompatibleAI/imitation/issues/486, https://github.com/HumanCompatibleAI/imitation/pull/519

dfilan avatar Sep 01 '22 17:09 dfilan

Thanks for your comments. Is there anything I could try to work around this problem? What if I set unwrap to False to bypass the unwrap_traj function? By the way, I simply changed the expert to PPO, but the agent does not seem able to learn from the environment. What could be the problem?

BryanZ666 avatar Sep 04 '22 07:09 BryanZ666

Then you get a different error that I don't really understand - unwrap_traj ends up calling the 'rollout' key of some object that doesn't have that key.

https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/data/wrappers.py#L160 adds the "rollout" key. I do not know why that would be getting lost in algorithm.get_env() -- I guess one of the wrappers SB3 applies might eliminate the info dict, though that seems counterintuitive.
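For what it's worth, here is a sketch of building the env so that each underlying gym env is wrapped in RolloutInfoWrapper before vectorizing, so the "rollout" key is added below whatever SB3 wraps on top. AtariWrapper is assumed here as a stand-in for the preprocessing make_atari_env applies:

import gym
from stable_baselines3.common.atari_wrappers import AtariWrapper
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack
from imitation.data.wrappers import RolloutInfoWrapper

def make_env():
    env = gym.make("BeamRiderNoFrameskip-v4")
    env = AtariWrapper(env)         # roughly what make_atari_env applies per env
    return RolloutInfoWrapper(env)  # records obs/rews and exposes them via info["rollout"]

venv = VecFrameStack(DummyVecEnv([make_env]), n_stack=4)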

Setting unwrap to False is probably a viable workaround. It does mean you'll lose the last observation (VecEnv discards it as it auto-resets), but that might be OK for your use case.
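A minimal sketch of that workaround, assuming this imitation version exposes an unwrap keyword on rollout.rollout:

rollouts = rollout.rollout(
    expert,
    expert.get_env(),
    rollout.make_sample_until(min_timesteps=None, min_episodes=2),
    unwrap=False,  # skip unwrap_traj, so the "rollout" info key is never needed
)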

AdamGleave avatar Sep 05 '22 18:09 AdamGleave

Yeah, now it can roll out trajectories! But the agent could not learn a meaningful policy using the default BC parameters.

from imitation.algorithms import bc

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
)
bc_trainer.train(n_epochs=300)

Is there anything I could try to improve it?

BryanZ666 avatar Sep 07 '22 04:09 BryanZ666

To sanity check: are the transitions you're training BC on transposed correctly relative to the environment you're giving BC? I imagine it would error out if not, but if it were treating the channel dimension as the batch dimension, that would certainly prevent it from learning...
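A quick sketch of that check, assuming the transitions and expert from the earlier snippets:

venv = expert.get_env()
print(venv.observation_space.shape)  # e.g. (4, 84, 84) after SB3's channel transpose
print(transitions.obs.shape)         # (N, ...) with the batch dimension first
assert transitions.obs.shape[1:] == venv.observation_space.shape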

Otherwise, I don't have much to advise other than "tune hyperparameters". Behavioral cloning is a pretty weak imitation learning algorithm anyway, so it's common for it to require methods like data augmentation. Something like GAIL tends to be a lot more robust.
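One concrete thing to try, as a hedged sketch: if I recall correctly, BC's default policy is a small MLP, which is a poor fit for stacked image observations, so passing a CNN-based policy may help. The policy keyword below is an assumption and may be named differently (e.g. policy_class/policy_kwargs) in older imitation releases:

from stable_baselines3.common.policies import ActorCriticCnnPolicy

cnn_policy = ActorCriticCnnPolicy(
    observation_space=env.observation_space,
    action_space=env.action_space,
    lr_schedule=lambda _: 3e-4,  # constant learning-rate schedule
)
bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=transitions,
    policy=cnn_policy,  # assumption: keyword name may differ by imitation version
)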

AdamGleave avatar Sep 07 '22 06:09 AdamGleave

Closing due to inactivity.

AdamGleave avatar Oct 22 '22 01:10 AdamGleave