
[BUG] Actor updates too frequently in SAC with Policy Update Delay

Open · LeoHink opened this issue 1 year ago · 1 comment

Describe the bug

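The policy update delay in the SAC implementations can cause unexpected behaviour when num_envs*rollout_length and policy_update_delay have common factors. The delayed update is gated on t % policy_update_delay == 0, but we advance t with t += num_envs*rollout_length, so t counts environment steps rather than update steps. Assuming t starts at 0, if num_envs*rollout_length is a multiple of policy_update_delay the condition holds at every update step instead of every policy_update_delay steps as desired; more generally, it holds once every policy_update_delay / gcd(num_envs*rollout_length, policy_update_delay) update steps, so the delay only works as intended when the two values are coprime.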
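The arithmetic can be checked by enumerating t directly. This is a minimal sketch, not Mava code: it assumes t starts at 0 and is advanced by num_envs * rollout_length after every update step, and the helper names (actor_update_steps, firing_period) and example values (16 × 8 = 128, 3 × 2 = 6, 3 × 3 = 9) are hypothetical rather than Mava's defaults.

```python
from math import gcd


def actor_update_steps(num_envs: int, rollout_length: int,
                       policy_update_delay: int, num_updates: int) -> list[int]:
    """Indices of update steps at which t % policy_update_delay == 0.

    Hypothetical model: t starts at 0 and grows by num_envs * rollout_length
    per update step, so at update step i, t == i * num_envs * rollout_length.
    """
    step = num_envs * rollout_length
    return [i for i in range(num_updates) if (i * step) % policy_update_delay == 0]


def firing_period(num_envs: int, rollout_length: int, policy_update_delay: int) -> int:
    """Update steps between consecutive actor updates under the t-based check."""
    step = num_envs * rollout_length
    return policy_update_delay // gcd(step, policy_update_delay)


# 16 * 8 = 128 is a multiple of 4: the check passes at every update step.
print(actor_update_steps(16, 8, 4, 8))  # [0, 1, 2, 3, 4, 5, 6, 7]
# 3 * 2 = 6 and gcd(6, 4) = 2: the check passes every 2nd step, not every 4th.
print(actor_update_steps(3, 2, 4, 8))   # [0, 2, 4, 6]
# 3 * 3 = 9 is coprime with 4: the check passes every 4th step, as intended.
print(actor_update_steps(3, 3, 4, 8))   # [0, 4]
```

The period follows because i * step ≡ 0 (mod d) holds exactly when d / gcd(step, d) divides i.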

To Reproduce

Steps to reproduce the behavior:

  1. Run a SAC implementation (e.g., masac) with default hyperparameters.
  2. Add debug statements, such as jax.debug.print("Q Update") in update_q and jax.debug.print("Actor Update") in update_actor_and_alpha.
  3. Let the program run for a while, then observe the printed output, or track the number of gradient updates applied to the Q-functions and the actor.

Expected behavior

We'd expect "Q Update" to be printed policy_update_delay times before "Actor Update" is printed policy_update_delay times, e.g. if policy_update_delay = 4:

Q Update
Q Update
Q Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
...

But with the default hyperparameter settings, where num_envs*rollout_length and policy_update_delay share factors, we actually observe:

Q Update
Actor Update
Actor Update
Actor Update
Actor Update
Q Update
Actor Update
Actor Update
Actor Update
Actor Update
...

If we keep track of the gradient updates, we can see that the actor is updated 4 times as often as the Q-networks (policy_update_delay = 4), which I don't think is the intended behaviour.
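Both print sequences can be reproduced with a small pure-Python model of the training loop. The model is an assumption inferred from the expected output above, not Mava's actual code: each update step performs one Q update, and whenever the delay check passes, the actor/alpha update is applied policy_update_delay times in a row. Both counters are incremented before the check so that the first actor update follows policy_update_delay Q updates; simulate and step_size are hypothetical names.

```python
def simulate(num_updates: int, step_size: int, policy_update_delay: int,
             use_update_counter: bool) -> list[str]:
    """Replay the debug prints for num_updates update steps (hypothetical model)."""
    log: list[str] = []
    t = 0         # env-step counter, advanced by num_envs * rollout_length (step_size)
    update_t = 0  # proposed counter, advanced by 1 per update step
    for _ in range(num_updates):
        t += step_size
        update_t += 1
        log.append("Q Update")
        counter = update_t if use_update_counter else t
        if counter % policy_update_delay == 0:
            # Assumed: the delayed actor update runs policy_update_delay times.
            log.extend(["Actor Update"] * policy_update_delay)
    return log


delay = 4
buggy = simulate(8, step_size=128, policy_update_delay=delay, use_update_counter=False)
fixed = simulate(8, step_size=128, policy_update_delay=delay, use_update_counter=True)

print(buggy[:10])  # Q, then 4x Actor, repeated: the observed output
print(fixed[:10])  # 4x Q, then 4x Actor: the expected output
print(buggy.count("Actor Update") / buggy.count("Q Update"))  # 4.0
print(fixed.count("Actor Update") / fixed.count("Q Update"))  # 1.0
```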

Possible Solution

Use a separate counter that specifically counts update steps (e.g. update_t rather than t) and is incremented by 1 at every update step, and use it for the delay check in the train function, e.g.:

    params, opt_states, act_loss_info = lax.cond(
        # TD3-style delayed update, gated on the update counter rather than on t
        update_t % cfg.system.policy_update_delay == 0,
        update_actor_and_alpha,
        # Otherwise return the same params and opt_states, and 0 for the losses
        lambda params, opt_states, *_: (
            params,
            opt_states,
            {"actor_loss": 0.0, "alpha_loss": 0.0},
        ),
        params,
        opt_states,
        data,
        actor_key,
    )
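A dedicated counter can be carried through a JAX loop alongside t. In this sketch, update_q and update_actor_and_alpha are hypothetical stand-ins that only count how often they run, POLICY_UPDATE_DELAY and STEP_SIZE stand in for cfg.system.policy_update_delay and num_envs * rollout_length, and applying the actor update policy_update_delay times per trigger is again an assumption taken from the expected output. Only the counter passed to lax.cond differs between the two runs.

```python
import jax.numpy as jnp
from jax import lax

POLICY_UPDATE_DELAY = 4  # hypothetical stand-in for cfg.system.policy_update_delay
STEP_SIZE = 128          # hypothetical stand-in for num_envs * rollout_length


def update_q(counts):
    # Hypothetical stand-in for the real Q update: only counts how often it runs.
    return {**counts, "q": counts["q"] + 1}


def update_actor_and_alpha(counts):
    # Hypothetical stand-in: records POLICY_UPDATE_DELAY actor updates per trigger,
    # mirroring the repeated "Actor Update" lines in the expected output.
    return {**counts, "actor": counts["actor"] + POLICY_UPDATE_DELAY}


def make_update_step(use_update_counter: bool):
    def update_step(carry, _):
        counts, t, update_t = carry
        t = t + STEP_SIZE        # env-step counter, advanced by num_envs * rollout_length
        update_t = update_t + 1  # dedicated update counter, advanced by 1 per update step
        counts = update_q(counts)
        counter = update_t if use_update_counter else t
        counts = lax.cond(
            counter % POLICY_UPDATE_DELAY == 0,  # TD3-style delayed policy update
            update_actor_and_alpha,
            lambda c: c,  # skip the actor update and return counts unchanged
            counts,
        )
        return (counts, t, update_t), None

    return update_step


def run(use_update_counter: bool, num_updates: int = 8):
    init = ({"q": jnp.int32(0), "actor": jnp.int32(0)}, jnp.int32(0), jnp.int32(0))
    (counts, _, _), _ = lax.scan(
        make_update_step(use_update_counter), init, None, length=num_updates
    )
    return counts


buggy = run(use_update_counter=False)  # gate on t, as in the current code
fixed = run(use_update_counter=True)   # gate on update_t, as proposed
print({k: int(v) for k, v in buggy.items()})
print({k: int(v) for k, v in fixed.items()})
```

Over 8 update steps, gating on update_t gives the actor and the Q-networks 8 gradient updates each, while gating on t gives the actor 32, i.e. 4 times as many.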

LeoHink, Jan 13 '25 16:01

Hi @LeoHink, great find, this is indeed a bug! Would you be able to put up a PR for this? Your solution looks good to me 😄

sash-a, Jan 14 '25 07:01