stable-baselines3-contrib

[Feature Request] VecMaskWrapper for MaskablePPO

Open • CAI23sbP opened this issue 10 months ago • 2 comments

🚀 Feature

Follow-up to issue #68.

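@araffin, I am planning to use a Masknet (a custom neural network that masks invalid actions), which requires batch processing: the masks for all sub-environments should be computed in one call rather than one environment at a time.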

Here is my test code. The environment is 'CartPole-v1'.

Library versions: 'sb3_contrib==2.1.0', 'stable-baselines3==2.1.0'

  1. Modify the original get_action_masks helper in sb3_contrib
```python
import numpy as np
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.vec_env import VecEnv

EXPECTED_METHOD_NAME = "action_masks"


def get_action_masks(env: GymEnv) -> np.ndarray:
    """
    Returns the invalid action masks exposed by the gym env

    :param env: the Gym environment to get masks from
    :return: A numpy array of the masks
    """
    if isinstance(env, VecEnv):
        # Modified: VecMaskWrapper.get_attr() returns the batched masks directly
        return env.get_attr(EXPECTED_METHOD_NAME)
    else:
        # action_masks is a method, so it has to be called
        return getattr(env, EXPECTED_METHOD_NAME)()


def is_masking_supported(env: GymEnv) -> bool:
    """
    Checks whether gym env exposes a method returning invalid action masks

    :param env: the Gym environment to check
    :return: True if the method is found, False otherwise
    """
    if isinstance(env, VecEnv):
        try:
            env.get_attr(EXPECTED_METHOD_NAME)
            return True
        except AttributeError:
            return False
    else:
        return hasattr(env, EXPECTED_METHOD_NAME)
```

  2. Add a VecMaskWrapper that masks invalid actions
```python
import numpy as np
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper

EXPECTED_METHOD_NAME = "action_masks"


class VecMaskWrapper(VecEnvWrapper):
    """Computes CartPole action masks for all sub-environments in one batch."""

    def __init__(self, venv: VecEnv):
        super().__init__(venv)  # also sets self.num_envs
        n_actions = venv.action_space.n
        self.possible_actions = np.arange(n_actions)
        self.all_valid_mask = np.ones((self.num_envs, n_actions), dtype=np.bool_)

    def reset(self) -> VecEnvObs:
        self.observations = self.venv.reset()
        return self.observations

    def action_masks(self) -> np.ndarray:
        """
        https://www.gymlibrary.dev/environments/classic_control/cart_pole/
        The pole angle (observation index 2) can be observed in (-.418, .418) rad (±24°),
        but the episode terminates if it leaves (-.2095, .2095) rad (±12°).
        Once the pole leans more than 0.05 rad, the action that pushes it further over is masked.
        """
        masks = np.ones_like(self.all_valid_mask, dtype=np.bool_)
        leaning_left = np.where(self.observations[:, 2] <= -0.05)[0]
        leaning_right = np.where(self.observations[:, 2] >= 0.05)[0]
        masks[leaning_left, 1] = False  # forbid pushing the cart right
        masks[leaning_right, 0] = False  # forbid pushing the cart left
        return masks

    def get_attr(self, attr_name, indices=None):
        # Return the batched masks instead of one bound method per sub-env
        if attr_name == EXPECTED_METHOD_NAME:
            return self.action_masks()
        return super().get_attr(attr_name, indices)

    def step_wait(self) -> VecEnvStepReturn:
        self.observations, rews, dones, infos = self.venv.step_wait()
        return self.observations, rews, dones, infos
```
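The masking rule inside action_masks() can be isolated as a pure NumPy function, which makes it testable without creating an environment. This is a minimal sketch: cartpole_angle_masks is a hypothetical helper name, and it assumes action 0 pushes the cart left, action 1 pushes it right, and a negative pole angle means the pole leans left.

```python
import numpy as np


def cartpole_angle_masks(obs: np.ndarray, threshold: float = 0.05) -> np.ndarray:
    """Hypothetical helper: (n_envs, 4) CartPole observations -> (n_envs, 2) masks.

    Assumes action 0 pushes the cart left, action 1 pushes it right,
    observation index 2 is the pole angle in radians, and a negative
    angle means the pole leans left.
    """
    angle = obs[:, 2]
    masks = np.ones((obs.shape[0], 2), dtype=bool)
    # Leaning left past the threshold: pushing right would tip it further.
    masks[angle <= -threshold, 1] = False
    # Leaning right past the threshold: pushing left would tip it further.
    masks[angle >= threshold, 0] = False
    return masks


batch = np.array([
    [0.0, 0.0, -0.10, 0.0],  # leaning left
    [0.0, 0.0, 0.00, 0.0],   # roughly upright
    [0.0, 0.0, 0.07, 0.0],   # leaning right
])
masks = cartpole_angle_masks(batch)
```

Boolean row indexing here is equivalent to the np.where(...)[0] index arrays in the wrapper. Because the threshold is positive, no row can satisfy both conditions, so every sub-environment always keeps at least one valid action.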
  3. Example test code


```python
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks  # the modified version shown above

# VecMaskWrapper is the class defined above

if __name__ == "__main__":
    env_id = "CartPole-v1"
    vec_env = make_vec_env(env_id, n_envs=2)
    vec_env = VecMaskWrapper(vec_env)
    vec_env.reset()  # the wrapper needs an initial observation to compute masks
    model = MaskablePPO("MlpPolicy", vec_env, verbose=1)
    model.learn(total_timesteps=25_000)

    obs = vec_env.reset()
    for _ in range(1000):
        action_masks = get_action_masks(vec_env)
        action, _states = model.predict(obs, action_masks=action_masks)
        obs, rewards, dones, info = vec_env.step(action)
        vec_env.render()
```
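For context on how the masks are consumed: during training and in model.predict(..., action_masks=...), the masks restrict the policy's action distribution so invalid actions get (near-)zero probability. My understanding is that sb3_contrib's MaskableCategorical does this by replacing masked logits with a large negative constant; the NumPy sketch below illustrates that idea as a masked softmax. The helper name masked_action_probs and the constant HUGE_NEG are assumptions, not the library's code.

```python
import numpy as np

# Assumption: a large finite negative value stands in for -inf so that
# masked actions get (numerically) zero probability without producing NaNs.
HUGE_NEG = -1e8


def masked_action_probs(logits: np.ndarray, masks: np.ndarray) -> np.ndarray:
    """Hypothetical helper: softmax over logits with invalid actions suppressed."""
    masked_logits = np.where(masks, logits, HUGE_NEG)
    # Subtract the per-row maximum before exponentiating for numerical stability.
    shifted = masked_logits - masked_logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


logits = np.array([[0.2, 1.5], [0.3, -0.4]])
masks = np.array([[True, False], [True, True]])
probs = masked_action_probs(logits, masks)
```

A finite constant is used instead of -inf so that a row where every action is masked degrades to a uniform distribution instead of NaNs.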

Motivation

No response

Pitch

No response

Alternatives

No response

Additional context

No response

Checklist

  • [x] I have checked that there is no similar issue in the repo
  • [x] If I'm requesting a new feature, I have proposed alternatives

CAI23sbP, Feb 20 '25 20:02

Original issue: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/issues/68

@CAI23sbP I would simplify this by adding a hasattr() check inside the isinstance(env, VecEnv) branch, or by using is_vecenv_wrapped(env, VecMaskWrapper).
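A minimal sketch of the hasattr() variant suggested above, written against hypothetical stub classes (StubSubEnv, StubVecEnv, StubMaskedVecEnv) so it runs without SB3 installed. In real code the isinstance check would target VecEnv and the mask-providing class would be the proposed VecMaskWrapper; the default path mirrors stacking per-env masks from env_method().

```python
import numpy as np

EXPECTED_METHOD_NAME = "action_masks"


class StubSubEnv:
    """Hypothetical single env that exposes its own action mask."""

    def __init__(self, mask):
        self._mask = np.asarray(mask, dtype=bool)

    def action_masks(self) -> np.ndarray:
        return self._mask


class StubVecEnv:
    """Hypothetical stand-in for VecEnv; only env_method() is modelled."""

    def __init__(self, envs):
        self.envs = envs
        self.num_envs = len(envs)

    def env_method(self, method_name: str) -> list:
        # Like VecEnv.env_method: call the method on every sub-env, collect results.
        return [getattr(env, method_name)() for env in self.envs]


class StubMaskedVecEnv(StubVecEnv):
    """Hypothetical analogue of VecMaskWrapper: computes all masks in one batch."""

    def action_masks(self) -> np.ndarray:
        # Placeholder batch rule: every action valid for every sub-env.
        return np.ones((self.num_envs, 2), dtype=bool)


def get_action_masks(env) -> np.ndarray:
    if isinstance(env, StubVecEnv):  # real code: isinstance(env, VecEnv)
        if hasattr(env, EXPECTED_METHOD_NAME):
            # The vectorized env provides batched masks itself (e.g. from a Masknet).
            return np.asarray(getattr(env, EXPECTED_METHOD_NAME)())
        # Default path: ask each sub-env for its mask and stack them.
        return np.stack(env.env_method(EXPECTED_METHOD_NAME))
    # Single (non-vectorized) env: call its action_masks() method.
    return getattr(env, EXPECTED_METHOD_NAME)()


subs = [StubSubEnv([True, False]), StubSubEnv([True, True])]
plain_masks = get_action_masks(StubVecEnv(subs))
batched_masks = get_action_masks(StubMaskedVecEnv(subs))
single_mask = get_action_masks(StubSubEnv([False, True]))
```

With this check the get_attr() override in VecMaskWrapper would no longer be needed, and plain VecEnvs keep the stacking behaviour. I believe SB3's VecEnvWrapper forwards unknown attributes to the wrapped VecEnv, so hasattr() should also find action_masks on a VecMaskWrapper nested under other wrappers; is_vecenv_wrapped(env, VecMaskWrapper) is the more explicit alternative.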

Could you submit a PR with tests and doc?

araffin, Mar 31 '25 10:03

> use a Masknet (a custom neural network for masking invalid actions)

May I ask why you're not using MaskablePPO, which already does exactly that: masking invalid actions?

maxmax1992, Jul 01 '25 10:07