stable-baselines3-contrib
Possible issue with Maskable PPO
Describe the bug
I've been trying to work out why my Maskable PPO (MPPO) training is so slow (50 it/s) when plain PPO on Breakout runs at ~750 it/s.
The speed degradation only happens with MPPO when using a SubprocVecEnv. I dug a little deeper and I wonder if the following code is to blame:
def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Checks whether gym env exposes a method returning invalid action masks
    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """
    if isinstance(env, VecEnv):
        return np.stack(env.env_method(EXPECTED_METHOD_NAME))
    else:
        return getattr(env, EXPECTED_METHOD_NAME)()
I suspect that with a SubprocVecEnv, the extra env_method call on every step is quite a hit on performance, since it means one more round trip to each worker process on top of the step itself. I tested my theory by running the same training loop without SubprocVecEnv, and with that one change the speed rises to 120 it/s.
Does my reasoning make sense, and would it be smarter to return the valid action masks as a key in the info dict? I'm going to try making that change locally and see whether it changes things speed-wise.
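To make the info-dict idea concrete, here is a minimal sketch of what I have in mind. It is not how sb3-contrib works today: ToyMaskedEnv, MASK_KEY and masks_from_infos are hypothetical names made up for illustration, and the toy env only mimics the shape of a Gym-style step()/reset() API. The point is just that the mask rides back with the step result, so no separate env_method call is needed.

```python
import numpy as np

MASK_KEY = "action_mask"  # hypothetical info-dict key, not an sb3-contrib convention


class ToyMaskedEnv:
    """Hypothetical stand-in env that exposes an action_masks() method."""

    def __init__(self, n_actions: int = 4):
        self.n_actions = n_actions
        self.t = 0

    def action_masks(self) -> np.ndarray:
        # Toy rule: action (t mod n_actions) is invalid at time step t.
        mask = np.ones(self.n_actions, dtype=bool)
        mask[self.t % self.n_actions] = False
        return mask

    def reset(self):
        self.t = 0
        return np.zeros(1, dtype=np.float32), {MASK_KEY: self.action_masks()}

    def step(self, action: int):
        self.t += 1
        obs = np.full(1, self.t, dtype=np.float32)
        # Attach the mask for the *next* decision to the info dict, so it is
        # returned together with the step result.
        info = {MASK_KEY: self.action_masks()}
        return obs, 0.0, False, False, info


def masks_from_infos(infos) -> np.ndarray:
    """Stack the per-env masks found in a list of info dicts."""
    return np.stack([info[MASK_KEY] for info in infos])


# Usage: collect masks from the infos that come back from reset/step.
envs = [ToyMaskedEnv() for _ in range(3)]
infos = [env.reset()[1] for env in envs]
masks = masks_from_infos(infos)  # one row per env
```

One thing I'd still need to check: a VecEnv resets a sub-environment automatically when an episode ends, so the info returned on that step would have to carry the mask for the freshly reset state rather than for the terminal one.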