[BUG] Actor updates too frequently in SAC with Policy Update Delay
Describe the bug
The policy update delay in the SAC implementations can cause unexpected behaviour when num_envs*rollout_length and policy_update_delay share a common factor. The delayed actor update is triggered by t % policy_update_delay == 0, but t is advanced with t += num_envs*rollout_length rather than by 1 per update step. If num_envs*rollout_length is a multiple of policy_update_delay, the condition is true at every update step instead of every policy_update_delay steps as intended. With a smaller common factor, it is still true more often than intended.
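The arithmetic can be checked with a short sketch. It assumes t starts at 0 and grows by step_size = num_envs * rollout_length per update step. Then t % policy_update_delay == 0 holds every policy_update_delay // gcd(step_size, policy_update_delay) update steps, so the actor branch fires on every step whenever step_size is a multiple of policy_update_delay. The function names and step sizes below are illustrative only, not Mava's code or defaults.

```python
from math import gcd


def trigger_period(step_size: int, policy_update_delay: int) -> int:
    """Update steps between actor triggers when the check uses t.

    Assumes t starts at 0 and t += step_size (num_envs * rollout_length)
    after every update step. t = k * step_size is divisible by
    policy_update_delay exactly when k is a multiple of
    policy_update_delay // gcd(step_size, policy_update_delay).
    """
    return policy_update_delay // gcd(step_size, policy_update_delay)


def brute_force_period(step_size: int, policy_update_delay: int, n_updates: int = 1000) -> int:
    """Measure the same period by simulating the counter directly."""
    fired = [k for k in range(n_updates) if (k * step_size) % policy_update_delay == 0]
    return fired[1] - fired[0]


for step_size in (3, 6, 128):  # hypothetical values of num_envs * rollout_length
    print(step_size, trigger_period(step_size, 4), brute_force_period(step_size, 4))
```

With policy_update_delay = 4, a step size of 3 gives the intended period of 4, a step size of 6 gives 2, and a step size of 128 gives 1, meaning the actor branch runs at every update step.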
To Reproduce
Steps to reproduce the behavior:
- Run a SAC implementation (e.g., masac) with default hyperparameters.
- Add debug statements, such as jax.debug.print("Q Update") in update_q and jax.debug.print("Actor Update") in update_actor_and_alpha.
- Let the program run for a while, then observe the printed output, or count the gradient updates applied to the Q-functions and to the actor.
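The printed pattern can be reproduced without Mava using a plain-Python stand-in for the training loop. This sketch rests on two assumptions inferred from the output in this issue. First, every update step performs one Q update. Second, each time the delayed condition fires, update_actor_and_alpha performs policy_update_delay actor updates, which is why "Actor Update" appears in runs of four. simulate_log and its arguments are hypothetical names, and t starting at 0 is also an assumption.

```python
def simulate_log(step_size: int, policy_update_delay: int, n_updates: int) -> list[str]:
    """Return the debug messages a t-based delayed update would print.

    step_size stands in for num_envs * rollout_length (hypothetical values).
    """
    log = []
    t = 0  # assumed starting value of the environment-step counter
    for _ in range(n_updates):
        log.append("Q Update")  # one critic update per update step
        if t % policy_update_delay == 0:  # the buggy check on the env-step counter
            # Assumed: the actor branch loops policy_update_delay times.
            log.extend(["Actor Update"] * policy_update_delay)
        t += step_size
    return log


print("\n".join(simulate_log(step_size=128, policy_update_delay=4, n_updates=2)))
```

With a step size that is a multiple of 4, every "Q Update" is followed by four "Actor Update" lines, matching the observed output in this issue.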
Expected behavior
We'd expect "Q Update" to be printed policy_update_delay times before "Actor Update" is printed policy_update_delay times, e.g. if policy_update_delay = 4:
Q Update
Q Update
Q Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
...
But with the default hyperparameter setting, where num_envs*rollout_length and policy_update_delay share factors, we actually observe:
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
...
If we count the gradient updates, we find that the actor is updated 4 times as often as the Q networks (with policy_update_delay = 4), which I don't think is the intended behavior.
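Counting updates makes the size of the error explicit. This sketch assumes one Q update per update step, policy_update_delay actor updates per trigger, and t starting at 0 and advancing by step_size = num_envs * rollout_length. Under those assumptions, over whole trigger periods the actor receives gcd(step_size, policy_update_delay) updates per Q update instead of the intended 1. count_updates is a hypothetical helper.

```python
from math import gcd


def count_updates(step_size: int, policy_update_delay: int, n_updates: int) -> tuple[int, int]:
    """Return (q_updates, actor_updates) when the delay check uses t."""
    q_updates = actor_updates = 0
    t = 0  # assumed starting value
    for _ in range(n_updates):
        q_updates += 1  # one Q update per update step
        if t % policy_update_delay == 0:
            # Assumed: each trigger performs policy_update_delay actor updates.
            actor_updates += policy_update_delay
        t += step_size
    return q_updates, actor_updates


q, a = count_updates(step_size=128, policy_update_delay=4, n_updates=1000)
print(q, a, a / q)  # 1000 4000 4.0, and gcd(128, 4) == 4
```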
Possible Solution
Use a separate counter that counts only update steps (e.g., update_t rather than t) and is incremented by 1 at every update step, and use it for the delay check in the train function, e.g.:
from jax import lax

# update_t counts update steps (incremented by 1 per update step),
# unlike t, which advances by num_envs * rollout_length.
params, opt_states, act_loss_info = lax.cond(
    update_t % cfg.system.policy_update_delay == 0,  # TD3-style delayed update
    update_actor_and_alpha,
    # Otherwise return params and opt_states unchanged, with zero losses.
    lambda params, opt_states, *_: (
        params,
        opt_states,
        {"actor_loss": 0.0, "alpha_loss": 0.0},
    ),
    params,
    opt_states,
    data,
    actor_key,
)
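A plain-Python sketch of the proposed fix follows. A dedicated update_t counter is carried in the loop state alongside t, incremented by 1 per update step, and used for the delay check, so the trigger rate no longer depends on num_envs * rollout_length. LoopState, update_step and run_fixed are hypothetical names, and the sketch assumes each trigger performs policy_update_delay actor updates. In the real JAX train function, update_t would presumably need to be threaded through the loop carry in the same way.

```python
from typing import NamedTuple


class LoopState(NamedTuple):
    t: int  # environment steps, advanced by num_envs * rollout_length
    update_t: int  # update steps, advanced by 1
    q_updates: int
    actor_updates: int


def update_step(state: LoopState, step_size: int, policy_update_delay: int) -> LoopState:
    q_updates = state.q_updates + 1  # one Q update per update step
    actor_updates = state.actor_updates
    # Delay check on the dedicated update counter, not on t.
    if state.update_t % policy_update_delay == 0:
        # Assumed: the actor branch loops policy_update_delay times.
        actor_updates += policy_update_delay
    return LoopState(
        t=state.t + step_size,
        update_t=state.update_t + 1,
        q_updates=q_updates,
        actor_updates=actor_updates,
    )


def run_fixed(step_size: int, policy_update_delay: int, n_updates: int) -> LoopState:
    state = LoopState(t=0, update_t=0, q_updates=0, actor_updates=0)
    for _ in range(n_updates):
        state = update_step(state, step_size, policy_update_delay)
    return state


final = run_fixed(step_size=128, policy_update_delay=4, n_updates=1000)
print(final.q_updates, final.actor_updates)  # 1000 1000
```

Because update_t advances by exactly 1, the actor branch fires once every policy_update_delay update steps for any step size, so the actor and the Q networks end up with the same number of gradient updates.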
Hi @LeoHink, great find, this is indeed a bug! Would you be able to put up a PR for this? Your solution looks good to me :smile: