
[Question] Significant Performance Disparity Between Maskable PPO and PPO

Open gemelom opened this issue 9 months ago • 6 comments

❓ Question

I ran PPO and Maskable PPO on my custom environment with the same configuration, and found that Maskable PPO (~5 fps) is much slower than PPO (~140 fps).

Here's my configurations:

  • environment setup
    env.action_space = MultiBinary(339)
    
  • reproduction code
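    from sb3_contrib import MaskablePPO
    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from wandb.integration.sb3 import WandbCallback

    config = {
        "env_name": "my_env_name",
        "n_envs": 16,
        "policy_type": "MlpPolicy",
        "total_timesteps": 100000,
    }
    
    # make_vec_env uses DummyVecEnv by default
    vec_env = make_vec_env(config["env_name"], n_envs=config["n_envs"])
    
    model = MaskablePPO("MlpPolicy", vec_env, n_steps=128, verbose=1)
    # model = PPO("MlpPolicy", vec_env, n_steps=128, verbose=1)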
    
    model.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(
            gradient_save_freq=100,
            model_save_path=f"models/{experiment_name}",
            verbose=2,
        ),
        progress_bar=True,
    )
    

I also profiled my code with py-spy and found that MaskablePPO spends a lot of extra time in these lines:

[Image: py-spy profile of MaskablePPO]

while PPO spends much less time in train and most of its time in collect_rollouts just as expected.

I wonder whether this extreme drop in training speed is normal given the large action_space, or whether there is a bug in the implementation.


gemelom, Mar 08 '25 08:03

Hello,

> I also profiled my code with py-spy and found that MaskablePPO spends a lot of extra time in these lines

At least the slowdown is where it would be expected. I'm a bit surprised by how large the slowdown is, but the code was never optimized for speed, so there is probably room for improvement.

Related code: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/55a03328c600df76731c3cfa8ae6099d8d3d273a/sb3_contrib/common/maskable/distributions.py#L245-L252

and

https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/55a03328c600df76731c3cfa8ae6099d8d3d273a/sb3_contrib/common/maskable/distributions.py#L58-L62
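
To illustrate why the cost could grow with the action space, here is a minimal pure-Python sketch of per-dimension masking for a MultiBinary action. It is not the sb3-contrib implementation. It assumes each binary dimension is treated as its own 2-way categorical, that the mask is a flat list of 2 * n booleans, and that disallowed choices get a large negative logit; `masked_log_prob` and `NEG_INF_LOGIT` are hypothetical names used only for this example.

```python
import math

# Assumed stand-in for "masked out"; the exact constant is an assumption.
NEG_INF_LOGIT = -1e8


def masked_log_prob(logits, mask, action):
    """Log-probability of one MultiBinary action under per-dimension masking.

    logits: list of n [logit_for_0, logit_for_1] pairs
    mask:   flat list of 2 * n booleans, True means the choice is allowed
    action: list of n ints, each 0 or 1
    """
    total = 0.0
    # One small 2-way categorical per binary dimension: this loop runs n times
    # per sample, so the work scales with the size of the action space.
    for i, (pair, a) in enumerate(zip(logits, action)):
        allowed = mask[2 * i : 2 * i + 2]
        masked = [l if ok else NEG_INF_LOGIT for l, ok in zip(pair, allowed)]
        # Numerically stable log-softmax over the two masked logits.
        m = max(masked)
        log_z = m + math.log(sum(math.exp(l - m) for l in masked))
        total += masked[a] - log_z
    return total


n = 339  # same size as MultiBinary(339) in the question
logits = [[0.0, 0.0] for _ in range(n)]
mask = [True, True] * n
mask[1] = False  # forbid choosing 1 in dimension 0
action = [0] * n
lp = masked_log_prob(logits, mask, action)
```

With n = 339 the loop body runs 339 times for every sample in every minibatch. If the library does something similar with one distribution object per dimension, that could plausibly account for much of the extra time in train, though only profiling the linked lines would confirm it.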

araffin, Mar 10 '25 11:03