
[Bug]: MaskablePPO inaccurate update counting when target_kl early exits

Open Sean-Fuhrman opened this issue 8 months ago • 2 comments

🐛 Bug

When MaskablePPO exits its training loop early due to target_kl, n_updates is still incremented by self.n_epochs instead of by the number of epochs actually completed. For example, if training stops early at epoch 5/10, n_updates is increased by 10 when it should be increased by 5.

To fix: on line 413 of ppo_mask.py, self._n_updates += self.n_epochs should be changed to self._n_updates += 1 and moved inside the epoch loop (around line 409), to match normal PPO.
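
For context, here is a minimal sketch of the affected loop structure, simplified from SB3's PPO-style train() method (names such as rollout_buffer, approx_kl_div, and continue_training follow SB3 conventions; the gradient-step details are elided):

continue_training = True
for epoch in range(self.n_epochs):
    for rollout_data in self.rollout_buffer.get(self.batch_size):
        ...  # compute losses and approx_kl_div, take a gradient step
        if self.target_kl is not None and approx_kl_div > 1.5 * self.target_kl:
            continue_training = False  # stop optimizing on this rollout
            break
    self._n_updates += 1  # proposed fix: count each completed epoch here
    if not continue_training:
        break
# Buggy version: self._n_updates += self.n_epochs sat after this loop,
# so an early exit still counted all n_epochs epochs.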

To Reproduce

from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=10, n_invalid_actions=3)

model = MaskablePPO(
    policy=MaskableActorCriticPolicy,
    env=env,
    verbose=1,
    target_kl=0.0003,  # set low to ensure early stopping
)

# Train
model.learn(total_timesteps=100_000)
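
With verbose=1, each early stop is reported in the console, and the inflated counter can be inspected directly after training (a quick check; note that _n_updates is a private SB3 attribute):

# Every rollout that stopped early still bumped the counter by n_epochs
# (10 by default), so this value overstates the epochs actually run:
print(model._n_updates)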

Relevant log output / Error message


System Info

No response

Sean-Fuhrman avatar Apr 24 '25 19:04 Sean-Fuhrman

Hello, could you do a PR to fix this issue?

araffin avatar Apr 25 '25 05:04 araffin

Hi, I've tested @Sean-Fuhrman's fix and it successfully resolves this issue.

I also discovered that the same n_updates counting bug exists in RecurrentPPO. Here's a minimal reproduction:

from sb3_contrib import RecurrentPPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=10, n_invalid_actions=3)

# RecurrentPPO exhibits the same n_updates counting bug
model = RecurrentPPO(
    policy="MlpLstmPolicy",
    env=env,
    verbose=1,
    target_kl=0.003,  # set low to ensure early stopping
)

# Train long enough to trigger early stopping and observe the bug
model.learn(total_timesteps=100_000)

Before the fix (incorrect behavior):

[training log screenshot]

n_updates increments by the full n_epochs despite early stopping.

After the fix (correct behavior):

[training log screenshot]

n_updates now increments only by the number of epochs actually completed.

I've submitted a PR (#313) that applies the same fix to both MaskablePPO and RecurrentPPO. The other algorithms in sb3-contrib don't appear to have this issue.

Thanks @Sean-Fuhrman for identifying the root cause!

alektebel avatar Nov 16 '25 22:11 alektebel