SAC code doesn't implement target_q correctly
The computation of the target Q value in the SERL SAC code's `critic_loss_fn()` has a potential bug.
In this file, if you set `config['backup_entropy'] = True`, the term `temperature * next_action_log_probs` is subtracted from `target_q`. This is mathematically equivalent to
$$y(r, s', d) = r + \gamma (1 - d) \left[ \min_{1,2} Q(s', a') \right] - \alpha \log \pi_{\theta}(a' \mid s')$$
where $y(r, s', d)$ = `target_q`, $r$ = `batch['rewards']`, $\gamma$ = `config['discount']`, $(1 - d)$ = `batch['masks']`, $\min_{1,2} Q(s', a')$ = `target_next_min_q`, $\alpha$ = `temperature`, $\log \pi_{\theta}(a' \mid s')$ = `next_action_log_probs`, and $a' \sim \pi(\cdot \mid s')$.
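In code, the current behavior amounts to something like the following sketch. The variable names follow the mapping above; this is a paraphrase for illustration, not the exact SERL source:

```python
import jax.numpy as jnp

def current_target_q(batch, target_next_min_q, next_action_log_probs,
                     temperature, discount):
    """Paraphrase of the current computation with backup_entropy=True.

    The entropy term is subtracted *outside* the discount/mask product,
    matching the first equation above.
    """
    return (
        batch['rewards']
        + discount * batch['masks'] * target_next_min_q
        - temperature * next_action_log_probs
    )
```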
But the formula for `target_q` should be
$$y(r, s', d) = r + \gamma (1 - d) \left[ \min_{1,2} Q(s', a') - \alpha \log \pi_{\theta}(a' \mid s') \right]$$
i.e. the $\alpha \log \pi_{\theta}(a' \mid s')$ term should also be multiplied by $\gamma (1 - d)$, so that the entropy bonus is discounted along with the rest of the backup and the value estimates stay accurate. In particular, for a terminal transition where $(1 - d) = 0$, the current code still adds $-\alpha \log \pi_{\theta}(a' \mid s')$ to the target, while the correct target reduces to just $r$.
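A quick numeric check makes the difference concrete, especially for terminal transitions. The numbers below are made up purely for illustration:

```python
import jax.numpy as jnp

rewards = jnp.array([1.0, 0.0])
masks = jnp.array([1.0, 0.0])    # (1 - d); the second transition is terminal
discount, temperature = 0.99, 0.2
min_q = jnp.array([5.0, 5.0])    # min over the two target critics
log_probs = jnp.array([-1.0, -1.0])

# Current behavior: entropy term outside the discount/mask product.
buggy = rewards + discount * masks * min_q - temperature * log_probs
# Correct backup: entropy term inside, scaled by discount * (1 - d).
fixed = rewards + discount * masks * (min_q - temperature * log_probs)

print(buggy)  # [6.15  0.2 ] -- the terminal target still carries the entropy term
print(fixed)  # [6.148 0.  ] -- the terminal target correctly reduces to the reward
```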
Sources: [1] the SAC paper, Eq. 3, for the soft value function; [2] OpenAI's Spinning Up, SAC pseudocode, line 12, for computing the target Q values.
The fix is quite simple: subtract the $\alpha \log \pi_{\theta}(a' \mid s')$ term before multiplying by $\gamma (1 - d)$.
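Concretely, something along these lines; again a sketch against the variable names above, not a drop-in patch for the SERL file:

```python
import jax.numpy as jnp

def fixed_target_q(batch, target_next_min_q, next_action_log_probs,
                   temperature, discount, backup_entropy=True):
    """Soft Bellman backup with the entropy bonus inside the bracket,
    so it is scaled by discount * (1 - d) along with min Q."""
    next_q = target_next_min_q
    if backup_entropy:
        next_q = next_q - temperature * next_action_log_probs
    return batch['rewards'] + discount * batch['masks'] * next_q
```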
Perhaps this is why the critic loss values in wandb remained positive while I was training the model with SAC. (°ー°〃)