Add a bias_correction_v flag to scale_by_amsgrad to align with the original AMSGrad paper and the PyTorch/TensorFlow implementations

Open vvsvictor opened this issue 3 months ago • 3 comments

Resolves #1389

vvsvictor avatar Sep 26 '25 14:09 vvsvictor

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

google-cla[bot] avatar Sep 26 '25 14:09 google-cla[bot]

Wouldn't setting bias_correction_v=False now skip the b2 bias correction entirely?

Is the point to drop the bias correction altogether, or to change the order of operations as discussed here: https://github.com/pytorch/pytorch/issues/142323

PyTorch still applies the bias correction in AMSGrad to this day: https://github.com/pytorch/pytorch/blob/2164b661219ab0a76aa018e955ba3d8e8f99c083/torch/optim/adam.py#L509

But TensorFlow does not (I think): https://github.com/keras-team/keras/blob/f6c4ac55692c132cd16211f4877fac6dbeead749/keras/src/optimizers/adam.py#L130-L150

rdyro avatar Oct 06 '25 17:10 rdyro

I think optax's original implementation is the one that makes the most sense (applying the bias correction after taking the max does not make sense to me). However, one could also simply remove the bias correction; see the plots in https://github.com/google-deepmind/optax/issues/1389. Removing it seems to potentially improve results and, most importantly, it aligns with the paper.

vroulet avatar Oct 10 '25 23:10 vroulet
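
The three orderings discussed in this thread can be compared on a scalar toy example. What follows is a minimal sketch in plain Python, not optax, PyTorch, or Keras code: the function name `second_moment_trace` and the mode strings are hypothetical, chosen only for illustration, and the mapping of each mode to a library is the one suggested by the comments above rather than something verified here.

```python
def second_moment_trace(grads, b2=0.999, mode="correct_then_max"):
    """Trace the second-moment term of an AMSGrad-style update.

    Hypothetical helper for illustration only. `grads` is a sequence of
    scalar gradients; the return value is the quantity whose square root
    would sit in the denominator at each step.
    """
    modes = ("correct_then_max", "max_then_correct", "no_correction")
    if mode not in modes:
        raise ValueError(f"mode must be one of {modes}, got {mode!r}")

    nu = 0.0      # exponential moving average of squared gradients
    nu_max = 0.0  # running maximum kept by AMSGrad
    trace = []
    for t, g in enumerate(grads, start=1):
        nu = b2 * nu + (1.0 - b2) * g * g
        correction = 1.0 - b2 ** t  # Adam-style bias correction factor

        if mode == "correct_then_max":
            # Bias-correct first, then take the running maximum.
            nu_max = max(nu_max, nu / correction)
            v = nu_max
        elif mode == "max_then_correct":
            # Take the maximum of the raw average, then bias-correct.
            nu_max = max(nu_max, nu)
            v = nu_max / correction
        else:
            # No bias correction on the second moment at all.
            nu_max = max(nu_max, nu)
            v = nu_max
        trace.append(v)
    return trace


if __name__ == "__main__":
    grads = [1.0, 0.0, 0.0]
    for mode in ("correct_then_max", "max_then_correct", "no_correction"):
        print(mode, second_moment_trace(grads, b2=0.9, mode=mode))
```

With `grads = [1.0, 0.0, 0.0]` and `b2 = 0.9`, the correct-then-max trace stays at 1.0 and the uncorrected trace stays at 0.1, while the max-then-correct trace falls from 1.0 to roughly 0.526 and then 0.369. In this toy case, correcting after the max means the denominator term is no longer non-decreasing, which may be one way to read the objection to that ordering.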