Weird behavior training CNN policies with train_rl
Problem
Not sure I would call this a bug, but it is definitely unintuitive behavior. The fundamental issue is that all CNN policies in SB3 are just ActorCriticPolicies with a single changed default argument (compare). The only thing that makes a CnnPolicy a CnnPolicy is the NatureCNN default feature extractor.
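For reference, the SB3 definition is roughly the following (an abridged sketch, not the verbatim source; the real class lives in stable_baselines3.common.policies and spells out every other argument instead of forwarding **kwargs):

from typing import Type

import gym
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, NatureCNN
from stable_baselines3.common.type_aliases import Schedule


# Abridged sketch of SB3's ActorCriticCnnPolicy: the only difference from
# ActorCriticPolicy is the default value of features_extractor_class.
class ActorCriticCnnPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN,
        **kwargs,  # all remaining arguments are identical to ActorCriticPolicy
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            features_extractor_class=features_extractor_class,
            **kwargs,
        )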
If I want to train a CNN policy in an image-based env this is the relevant part of the config I need to modify:
train:
  n_episodes_eval = 50  # Num of episodes for final mean ground truth return
  policy_cls = <class 'stable_baselines3.common.policies.ActorCriticCnnPolicy'>  # Training
  policy_kwargs:
    features_extractor_class = <class 'imitation.policies.base.NormalizeFeaturesExtractor'>
    features_extractor_kwargs:
      normalize_class = <class 'imitation.util.networks.RunningNorm'>
This will silently give me an MlpPolicy, even though it looks like I should have a CnnPolicy from the config.
Now, knowing this, I might think I can circumvent it by setting policy_kwargs=dict() so that the NatureCNN default doesn't get overridden.
This again silently gives me an MlpPolicy, due to this config hook.
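The effect of that hook is roughly the following (a purely illustrative sketch of the observed behavior, not the actual imitation source; apply_policy_kwargs_defaults is a hypothetical helper name):

from imitation.policies.base import NormalizeFeaturesExtractor
from imitation.util.networks import RunningNorm


def apply_policy_kwargs_defaults(policy_kwargs: dict) -> dict:
    # Hypothetical stand-in for imitation's sacred config hook: whenever no
    # (truthy) feature extractor is configured, the hook fills in imitation's
    # own defaults, so CnnPolicy's NatureCNN default never takes effect.
    if not policy_kwargs.get("features_extractor_class"):
        policy_kwargs = dict(
            policy_kwargs,
            features_extractor_class=NormalizeFeaturesExtractor,
            features_extractor_kwargs=dict(normalize_class=RunningNorm),
        )
    return policy_kwargs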
The only way to actually get a CNN policy is to explicitly set policy_kwargs=dict(features_extractor_class=NatureCNN), which makes CnnPolicy obsolete, as this is equivalent to just using ActorCriticPolicy as the policy.
(I say it uses an MlpPolicy silently, but if you have the log level set to INFO and remember to look, it does print a summary of the network architecture.)
Here is a small script that illustrates these problems. Since there are no image-based environments in seals, I use procgen here.
"""Thin wrapper around imitation's train_rl script"""
from imitation.scripts.config.train_rl import train_rl_ex
from imitation.scripts.train_rl import main_console
from stable_baselines3.common.torch_layers import NatureCNN
from stable_baselines3.ppo import CnnPolicy
# doesn't work
@train_rl_ex.named_config
def a():
common = dict(env_name="procgen:procgen-coinrun-v0")
train = dict(
policy_cls=CnnPolicy,
)
locals() # make flake8 happy
# does not work
@train_rl_ex.named_config
def b():
common = dict(env_name="procgen:procgen-coinrun-v0")
train = dict(
policy_cls=CnnPolicy,
policy_kwargs=dict(features_extractor_class=None),
)
locals() # make flake8 happy
# does not work
@train_rl_ex.named_config
def c():
common = dict(env_name="procgen:procgen-coinrun-v0")
train = dict(
policy_cls=CnnPolicy,
policy_kwargs=dict(features_extractor_class={}),
)
locals() # make flake8 happy
# works
@train_rl_ex.named_config
def d():
common = dict(env_name="procgen:procgen-coinrun-v0")
train = dict(
policy_cls=CnnPolicy,
policy_kwargs=dict(features_extractor_class=NatureCNN),
)
locals() # make flake8 happy
# doesn't work
@train_rl_ex.named_config
def e():
common = dict(env_name="procgen:procgen-coinrun-v0")
train = dict(
policy_cls=CnnPolicy,
policy_kwargs=dict(),
)
locals() # make flake8 happy
if __name__ == "__main__": # pragma: no cover
main_console()
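Assuming the script above is saved as, say, train_rl_cnn.py (hypothetical file name), each named config can be selected on the command line in the usual sacred fashion, e.g.
python train_rl_cnn.py with d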
Solution
Not sure how to solve this, but I would expect CnnPolicy to work out of the box.
I guess the fact that the distinguishing feature of a CnnPolicy is just a default argument is a bit weird (or at least it doesn't mesh well with imitation).
Huh, I did not notice this when playing around with training CNN policies and rewards.
I think either of these two CnnPolicies would mesh better with imitation. Maybe imitation could just provide its own CnnPolicy?
Don't allow changing the feature extractor (not super elegant; the parameter is still needed for interface compatibility):
class ActorCriticCnnPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Optional[Type[BaseFeaturesExtractor]] = None,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            NatureCNN,  # features_extractor_class is ignored; always use NatureCNN
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
More complicated alternative: Maybe a policy like this would be better?
class ActorCriticCnnPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        sde_net_arch: Optional[List[int]] = None,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Optional[Type[BaseFeaturesExtractor]] = None,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        def f(*args, **kwargs):
            # Chain the user-supplied extractor with a NatureCNN.
            extractor = features_extractor_class(*args, **kwargs)
            cnn = NatureCNN(*args, **kwargs)
            return nn.Sequential(extractor, cnn)

        # Fall back to a plain NatureCNN when no extractor class is given.
        features_extractor = NatureCNN if features_extractor_class is None else f

        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            ortho_init,
            use_sde,
            log_std_init,
            full_std,
            sde_net_arch,
            use_expln,
            squash_output,
            features_extractor,
            features_extractor_kwargs,
            normalize_images,
            optimizer_class,
            optimizer_kwargs,
        )
Not sure if this is the direction you want to go, though.
Can we just provide a named config to set reasonable defaults for CNN policies?
Yes, that would work!
@PavelCz is this still a live issue? If so, could you maybe summarize what needs to be in the default config and I'll assign it to someone?
@AdamGleave Just tested with the current master. Yes, the same issue still exists.
python -m imitation.scripts.train_rl with common.env_name=AsteroidsNoFrameskip-v4
gives you a feedforward (MLP) policy, as it should.
Using named config a) from above, which would be the most natural one to me, still gives you this:
INFO - imitation.scripts.common.rl - Policy network summary:
ActorCriticCnnPolicy(
  (features_extractor): NormalizeFeaturesExtractor(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (normalize): RunningNorm()
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential()
    (policy_net): Sequential(
      (0): Linear(in_features=100800, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=100800, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=14, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)
which doesn't contain any convolutional layers.
Named config d) works and gives:
INFO - imitation.scripts.common.rl - Policy network summary:
ActorCriticCnnPolicy(
  (features_extractor): NatureCNN(
    (cnn): Sequential(
      (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=22528, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential()
    (policy_net): Sequential()
    (value_net): Sequential()
  )
  (action_net): Linear(in_features=512, out_features=14, bias=True)
  (value_net): Linear(in_features=512, out_features=1, bias=True)
)
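Instead of reading through the printed summaries, the difference can also be checked programmatically. A minimal sketch, assuming policy is the constructed SB3 policy object (e.g. model.policy):

import torch.nn as nn


def contains_conv(policy: nn.Module) -> bool:
    # True iff some submodule of the policy is a convolutional layer.
    return any(isinstance(m, nn.Conv2d) for m in policy.modules())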
So I would propose the following as a default named config for CNNs:
@train_rl_ex.named_config
def cnn_policy():
    train = dict(
        policy_cls=CnnPolicy,
        policy_kwargs=dict(features_extractor_class=NatureCNN),
    )
    locals()  # make flake8 happy
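Assuming such a named config is registered in imitation.scripts.config.train_rl, it could then be combined with any image-based env on the command line, e.g.
python -m imitation.scripts.train_rl with cnn_policy common.env_name=AsteroidsNoFrameskip-v4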
Fixed in #610