
SAC code doesn't appropriately implement target_q

Open • gautams3 opened this issue 8 months ago • 1 comment

There is a potential bug in how the SERL SAC code computes the target Q-value in `critic_loss_fn()`.

In this file, if you set `config['backup_entropy'] = True`, the term `temperature * next_actions_log_probs` is subtracted from `target_q` after the discounted, masked bootstrap term has already been added. This is mathematically equivalent to

$$y(r,s',d) = r + \gamma (1-d) \left[ \min_{1,2} Q(s',a') \right] - \alpha \log \pi_{\theta}(a'|s')$$

where $y(r,s',d)$ = `target_q`, $r$ = `batch['rewards']`, $\gamma$ = `config['discount']`, $(1-d)$ = `batch['masks']`, $\min_{1,2} Q(s',a')$ = `target_next_min_q`, $\alpha$ = `temperature`, $\log \pi_{\theta}(a'|s')$ = `next_actions_log_probs`, and $a' \sim \pi_{\theta}(\cdot | s')$.
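To make the ordering concrete, here is a minimal pure-Python sketch of the computation as described above. It is not the actual SERL source: the function name `current_target_q` is hypothetical, and batches are plain lists of floats rather than JAX arrays. The variable names mirror the ones in this issue.

```python
def current_target_q(rewards, discount, masks, target_next_min_q,
                     temperature, next_actions_log_probs, backup_entropy=True):
    """Hypothetical re-creation of the ordering described in this issue.

    The entropy term is subtracted AFTER discounting and masking, so it is
    neither discounted by gamma nor zeroed out at terminal transitions.
    """
    # r + gamma * (1 - d) * min_{1,2} Q(s', a')
    target_q = [r + discount * m * q
                for r, m, q in zip(rewards, masks, target_next_min_q)]
    if backup_entropy:
        # - alpha * log pi(a'|s'), applied outside the discounted term
        target_q = [t - temperature * lp
                    for t, lp in zip(target_q, next_actions_log_probs)]
    return target_q
```

With `masks = [0.0]` (a terminal transition) the bootstrap term vanishes, but `-temperature * next_actions_log_probs` is still added to the target.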

But the formula for target_q should be

$$y(r,s',d) = r + \gamma (1-d) \left[ \min_{1,2} Q(s',a') - \alpha \log \pi_{\theta}(a'|s') \right]$$

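That is, the $\alpha \log \pi_{\theta}(a'|s')$ term should also be multiplied by $\gamma (1-d)$. The entropy bonus is part of the soft value of the next state $s'$, so it must be discounted like the rest of that value and dropped at terminal transitions ($d = 1$), where there is no next state to bootstrap from. Otherwise the learned value function is biased.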
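Subtracting the current target from the correct one, using the same symbols as above, shows exactly how large the discrepancy is:

$$
\begin{aligned}
y_{\text{correct}} - y_{\text{current}}
&= \Big(r + \gamma(1-d)\big[\min_{1,2} Q(s',a') - \alpha \log \pi_\theta(a'|s')\big]\Big)
 - \Big(r + \gamma(1-d)\min_{1,2} Q(s',a') - \alpha \log \pi_\theta(a'|s')\Big) \\
&= \alpha \log \pi_\theta(a'|s')\,\big(1 - \gamma(1-d)\big)
\end{aligned}
$$

For a non-terminal transition ($d = 0$) the gap is $(1-\gamma)\,\alpha \log \pi_\theta(a'|s')$, which is small when $\gamma$ is close to 1. For a terminal transition ($d = 1$) it is the full $\alpha \log \pi_\theta(a'|s')$, i.e. the current ordering still adds an entropy bonus for a next state that does not exist.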

Sources: [1] SAC paper (Haarnoja et al.), see Eq. 3 for the soft value function; [2] OpenAI Spinning Up, SAC pseudocode, see line 12 for computing the target Q-values.

The fix is simple: subtract the $\alpha \log \pi_{\theta}(a'|s')$ term from `target_next_min_q` before multiplying by $\gamma (1-d)$.
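A minimal pure-Python sketch of the corrected ordering follows. It assumes the same list-of-floats batches as above; `fixed_target_q` is a hypothetical helper that mirrors the corrected formula, not a copy of SERL's JAX code.

```python
def fixed_target_q(rewards, discount, masks, target_next_min_q,
                   temperature, next_actions_log_probs, backup_entropy=True):
    """Soft Bellman target with the entropy term inside the discounted part.

    Hypothetical helper illustrating the fix proposed in this issue.
    """
    next_v = list(target_next_min_q)
    if backup_entropy:
        # Soft value of the next state: min Q(s', a') - alpha * log pi(a'|s')
        next_v = [q - temperature * lp
                  for q, lp in zip(next_v, next_actions_log_probs)]
    # Discount and mask the whole soft value, entropy term included
    return [r + discount * m * v
            for r, m, v in zip(rewards, masks, next_v)]
```

Because the mask now multiplies the entropy term as well, a terminal transition (`masks = [0.0]`) yields a target equal to the reward alone.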

gautams3 • May 14 '25 23:05

Perhaps this is why the critic loss values in wandb stayed positive while I was training the model with SAC. (°ー°〃)

Leon-god • Sep 12 '25 01:09