[Question] Significant Performance Disparity Between Maskable PPO and PPO
❓ Question
I ran PPO and Maskable PPO on my custom environment with the same configuration, and Maskable PPO (~5 fps) is much slower than PPO (~140 fps).
Here is my setup:
- environment setup

      env.action_space = MultiBinary(339)

- reproduction code

      from sb3_contrib import MaskablePPO
      from stable_baselines3 import PPO  # only needed for the commented-out baseline
      from stable_baselines3.common.env_util import make_vec_env
      from wandb.integration.sb3 import WandbCallback

      config = {
          "env_name": "my_env_name",
          "n_envs": 16,
          "policy_type": "MlpPolicy",
          "total_timesteps": 100000,
      }

      # make_vec_env uses DummyVecEnv by default
      vec_env = make_vec_env(config["env_name"], n_envs=config["n_envs"])

      model = MaskablePPO("MlpPolicy", vec_env, n_steps=128, verbose=1)
      # model = PPO("MlpPolicy", vec_env, n_steps=128, verbose=1)

      model.learn(
          total_timesteps=config["total_timesteps"],
          callback=WandbCallback(
              gradient_save_freq=100,
              model_save_path=f"models/{experiment_name}",
              verbose=2,
          ),
          progress_bar=True,
      )
I also profiled my code with py-spy and found that MaskablePPO spends a lot of extra time in these lines,
while PPO spends much less time in `train` and most of its time in `collect_rollouts`, just as expected.
Is a slowdown this extreme expected given the large action space, or does it point to a bug in the implementation?
Checklist
- [x] I have checked that there is no similar issue in the repo
- [x] I have read the documentation
- [x] If code there is, it is minimal and working
- [x] If code there is, it is formatted using the markdown code blocks for both code and stack traces.
Hello,
> I also profiled my code with py-spy and found that MaskablePPO spends a lot of extra time in these lines
At least the slowdown is where it would be expected. I'm a bit surprised by how large it is, but the code was never optimized for speed, so there is probably room for improvement.
Related code: https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/55a03328c600df76731c3cfa8ae6099d8d3d273a/sb3_contrib/common/maskable/distributions.py#L245-L252
and
https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/blob/55a03328c600df76731c3cfa8ae6099d8d3d273a/sb3_contrib/common/maskable/distributions.py#L58-L62
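
To illustrate why a large `MultiBinary` space may be costly, here is a minimal, framework-free sketch of the general masking technique. It is not the sb3-contrib implementation. It assumes that a `MultiBinary(n)` space is treated as `n` independent two-way choices, each masked separately, so the work grows with `n` in a Python-level loop. The names `masked_softmax` and `mask_multi_binary`, the `HUGE_NEG` value, and the example mask are all hypothetical; please check the linked lines for what the library actually does.

```python
import math

# Stand-in for "minus infinity": masked logits end up with ~0 probability.
HUGE_NEG = -1e8


def masked_softmax(logits, mask):
    """Return probabilities, pushing entries where mask is False to ~0."""
    masked = [x if keep else HUGE_NEG for x, keep in zip(logits, mask)]
    top = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


def mask_multi_binary(logits_per_dim, flat_mask):
    """Mask a MultiBinary(n) space as n independent 2-way choices.

    flat_mask has 2 * n entries: (allow_0, allow_1) for each dimension.
    """
    probs = []
    for i, logits in enumerate(logits_per_dim):  # one iteration per dimension
        mask = flat_mask[2 * i : 2 * i + 2]
        probs.append(masked_softmax(logits, mask))
    return probs


n = 339
logits = [[0.0, 0.0] for _ in range(n)]
# Hypothetical mask: forbid setting the first bit to 1, allow everything else.
mask = [True, False] + [True, True] * (n - 1)
probs = mask_multi_binary(logits, mask)
print(probs[0])  # first dimension: all mass on action 0
print(probs[1])  # unmasked dimension: uniform over both actions
```

With 339 dimensions, the loop body runs 339 times for every masking call. If the real code follows a similar per-dimension pattern, batching those dimensions into a single tensor operation would be one place to look for a speedup.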