[Bug]: MaskablePPO: inaccurate update counting when target_kl triggers an early exit
🐛 Bug
When MaskablePPO exits early because of target_kl, n_updates is still incremented by self.n_epochs instead of being incremented only for the epochs that actually ran. So if training exits early at epoch 5/10, n_updates increases by 10 when it should increase by 5.
To fix:
Line 413 of ppo_mask.py, `self._n_updates += self.n_epochs`, should be changed to `self._n_updates += 1` and moved to line 409, inside the epoch loop, to match normal PPO.
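To illustrate the difference between the two counting schemes, here is a minimal standalone sketch. It does not use sb3-contrib; `count_updates`, the one-KL-value-per-epoch simplification, and the 1.5 x target_kl threshold are assumptions made for the illustration, not the actual code in ppo_mask.py.

```python
def count_updates(kl_per_epoch, target_kl, n_epochs, per_epoch_increment):
    """Simulate one train() call and return how much n_updates grows."""
    n_updates = 0
    for epoch in range(n_epochs):
        continue_training = True
        # Stand-in for the minibatch loop: a single KL estimate per epoch.
        # The 1.5 factor is an assumption of this sketch.
        if target_kl is not None and kl_per_epoch[epoch] > 1.5 * target_kl:
            continue_training = False
        if per_epoch_increment:
            # Proposed behavior: count each epoch that actually ran.
            n_updates += 1
        if not continue_training:
            break
    if not per_epoch_increment:
        # Reported behavior: always add the full n_epochs after the loop.
        n_updates += n_epochs
    return n_updates


# KL stays small for 4 epochs, then exceeds the threshold in the 5th.
kl = [0.0001] * 4 + [0.01] * 6
buggy = count_updates(kl, target_kl=0.0003, n_epochs=10, per_epoch_increment=False)
fixed = count_updates(kl, target_kl=0.0003, n_epochs=10, per_epoch_increment=True)
print(buggy, fixed)
```

With the early stop in the 5th epoch, the first scheme adds 10 and the second adds 5, which matches the example in the description.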
To Reproduce
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=10, n_invalid_actions=3)

model = MaskablePPO(
    policy=MaskableActorCriticPolicy,
    env=env,
    verbose=1,
    target_kl=0.0003,  # set low to ensure early stop
)

# Train
model.learn(total_timesteps=100_000)
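One way to tell whether a run is affected is to compare the final n_updates against the largest value it could reach if no epoch were ever skipped. The sketch below computes that upper bound. The helper `max_n_updates` is hypothetical, and the values n_steps=2048, n_epochs=10 and a single environment are assumed defaults that are not stated in this issue.

```python
import math


def max_n_updates(total_timesteps, n_steps, n_envs, n_epochs):
    """Upper bound on n_updates: every rollout runs all n_epochs."""
    # Each rollout collects n_steps * n_envs timesteps before training.
    n_rollouts = math.ceil(total_timesteps / (n_steps * n_envs))
    return n_rollouts * n_epochs


# Assumed defaults (not stated in this issue): n_steps=2048, n_epochs=10, one env.
upper_bound = max_n_updates(100_000, n_steps=2048, n_envs=1, n_epochs=10)
print(upper_bound)  # 490
```

If early stopping fires on most rollouts, as the low target_kl is meant to ensure, a final n_updates exactly equal to this bound suggests the count is not tracking the epochs that actually ran.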
Checklist
- [x] I have checked that there is no similar issue in the repo
- [x] I have read the documentation
- [x] I have provided a minimal and working example to reproduce the bug
- [x] I've used the markdown code blocks for both code and stack traces.
Hello, could you do a PR to fix this issue?
Hi, I've tested @Sean-Fuhrman's fix and it successfully resolves this issue.
I also discovered that the same n_updates counting bug exists in RecurrentPPO. Here's a minimal reproduction:
from sb3_contrib import RecurrentPPO
from sb3_contrib.common.envs import InvalidActionEnvDiscrete

env = InvalidActionEnvDiscrete(dim=10, n_invalid_actions=3)

# RecurrentPPO with the n_updates bug
model = RecurrentPPO(
    policy="MlpLstmPolicy",
    env=env,
    verbose=1,
    target_kl=0.003,  # set low to ensure early stop
)

# Train for just a bit to see the bug
model.learn(total_timesteps=100_000)
Before the fix (incorrect behavior):
n_updates increments by full n_epochs despite early stopping
After the fix (correct behavior):
n_updates now correctly increments by actual epochs completed
I've submitted a PR (#313) that applies the same fix to both MaskablePPO and RecurrentPPO. The other algorithms in sb3-contrib don't appear to have this issue.
Thanks @Sean-Fuhrman for identifying the root cause!