[Feature Request] VecMaskWrapper for MaskablePPO
🚀 Feature
From the issue.
@araffin ! I am planning to use a Masknet (a custom neural network for masking invalid actions) that requires batch processing
Here is my test code. The environment name is 'CartPole-v1'.
My library versions are 'sb3_contrib==2.1.0, stable-baselines3==2.1.0'.
- Modify the original code
def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Retrieve the invalid action masks exposed by an environment.

    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """
    if isinstance(env, VecEnv):
        # Vectorized envs are queried by name through ``get_attr`` so that a
        # wrapper overriding ``get_attr`` can answer for the whole batch at once.
        return env.get_attr(EXPECTED_METHOD_NAME)

    masks = getattr(env, EXPECTED_METHOD_NAME)
    # ``action_masks`` is normally a method: call it so the caller receives the
    # masks themselves rather than a bound method object. A plain (non-callable)
    # attribute is still returned unchanged.
    return masks() if callable(masks) else masks
def is_masking_supported(env: GymEnv) -> bool:
    """
    Tell whether an environment exposes the attribute used to fetch action masks.

    :param env: the Gym environment to check
    :return: True if the method is found, False otherwise
    """
    # Plain environments: a simple attribute lookup is enough.
    if not isinstance(env, VecEnv):
        return hasattr(env, EXPECTED_METHOD_NAME)

    # Vectorized environments: probe through ``get_attr`` and treat an
    # AttributeError as "masking not available".
    try:
        env.get_attr(EXPECTED_METHOD_NAME)
    except AttributeError:
        return False
    else:
        return True
- Add a VecMaskWrapper: a wrapper that masks invalid actions
from typing import List
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper
import numpy as np
# Attribute name looked up on an environment to obtain its action masks.
EXPECTED_METHOD_NAME = 'action_masks'
class VecMaskWrapper(VecEnvWrapper):
    """
    Vectorized wrapper that computes invalid action masks for all sub-environments at once.

    The wrapper caches the latest batch of observations (from ``reset`` and
    ``step_wait``) and derives the masks from it. The masking rule is written for
    CartPole-style observations: column 2 of the observation batch is treated as
    the pole angle, see
    https://www.gymlibrary.dev/environments/classic_control/cart_pole/

    :param venv: the vectorized environment to wrap; its action space must
        expose ``n`` (number of discrete actions)
    :param angle_threshold: absolute angle from which one of the two actions
        is masked out. Must be strictly positive, otherwise both actions could
        be masked simultaneously.
    """

    def __init__(
        self,
        venv: VecEnv,
        angle_threshold: float = 0.05,
    ):
        if angle_threshold <= 0:
            raise ValueError(f"angle_threshold must be > 0, got {angle_threshold}")
        VecEnvWrapper.__init__(self, venv)
        n_actions = venv.action_space.n
        self.num_envs = venv.num_envs
        self.angle_threshold = angle_threshold
        self.possible_actions = np.arange(n_actions)
        self.all_valid_mask = np.ones((self.num_envs, n_actions), dtype=np.bool_)
        # No observation is available until the first reset(); initialising the
        # attribute avoids an AttributeError when masks are queried early (which
        # would make an AttributeError-based support check report "unsupported").
        self.observations = None

    def reset(self) -> VecEnvObs:
        """Reset the wrapped envs and cache the returned observations."""
        self.observations = self.venv.reset()
        return self.observations

    def action_masks(self) -> np.ndarray:
        """
        Compute the action masks for the cached observations.

        https://www.gymlibrary.dev/environments/classic_control/cart_pole/
        1. The pole angle can be observed between (-.418, .418) radians (or ±24°),
        but the episode terminates if the pole angle is not in the range (-.2095, .2095) (or ±12°)

        :return: boolean array of shape ``(num_envs, n_actions)`` where ``True``
            marks an allowed action. Every action is allowed when no observation
            has been cached yet (before the first ``reset``).
        """
        masks = np.ones_like(self.all_valid_mask, dtype=np.bool_)
        if self.observations is None:
            return masks
        angles = self.observations[:, 2]
        # Angle at or below -threshold: forbid action 1.
        masks[angles <= -self.angle_threshold, 1] = False
        # Angle at or above +threshold: forbid action 0.
        masks[angles >= self.angle_threshold, 0] = False
        return masks

    def get_attr(self, attr_name, indices=None):
        """
        Return an attribute from the wrapped envs, answering mask requests directly.

        :param attr_name: name of the attribute to fetch
        :param indices: sub-environment indices to restrict the result to
            (``None`` means all of them)
        :return: the batch of masks when ``attr_name`` is the mask attribute,
            otherwise whatever the parent ``get_attr`` returns
        """
        if attr_name != EXPECTED_METHOD_NAME:
            return super().get_attr(attr_name, indices)
        masks = self.action_masks()
        if indices is None:
            return masks
        # Honour the requested subset; atleast_1d keeps the result 2-D even
        # when a single integer index is given.
        return masks[np.atleast_1d(indices)]

    def step_wait(self) -> VecEnvStepReturn:
        """Wait for the wrapped envs' step and cache the new observations."""
        self.observations, rews, dones, infos = self.venv.step_wait()
        return self.observations, rews, dones, infos
- Example test code
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib.ppo_mask.ppo_mask import MaskablePPO
if __name__ == "__main__":
    env_id = "CartPole-v1"

    # Two CartPole copies, wrapped so that masks are produced for the whole batch.
    vec_env = VecMaskWrapper(make_vec_env(env_id, n_envs=2))
    # Prime the wrapper's cached observations, which action_masks() reads.
    vec_env.reset()

    model = MaskablePPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=25_000)

    # Roll out the trained policy, feeding it fresh masks at every step.
    obs = vec_env.reset()
    for _ in range(1000):
        action_masks = get_action_masks(vec_env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, rewards, dones, info = vec_env.step(action)
        vec_env.render()
Motivation
No response
Pitch
No response
Alternatives
No response
Additional context
No response
Checklist
- [x] I have checked that there is no similar issue in the repo
- [x] If I'm requesting a new feature, I have proposed alternatives
Original issue: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/68
@CAI23sbP I would simplify this by adding a hasattr() check inside the isinstance(env, VecEnv) branch, or by using an is_vec_env_wrapped(env, VecMaskWrapper) check.
Could you submit a PR with tests and doc?
use a Masknet (a custom neural network for masking invalid actions)
May I ask why you're not using MaskablePPO, which achieves exactly that: masking invalid actions?