Add a bias_correction_v flag to scale_by_amsgrad to align with the original AMSGrad paper and the PyTorch/TensorFlow implementations
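For context, a minimal sketch of what the flag is meant to toggle in the second-moment update. This is an illustration for a single array, not the actual optax code; the names `nu` / `nu_max` just follow optax's naming convention.

```python
import jax.numpy as jnp

def amsgrad_second_moment(g, nu, nu_max, b2, count, bias_correction_v=True):
    """Illustrative AMSGrad second-moment step for one array (not the optax implementation)."""
    nu = b2 * nu + (1.0 - b2) * jnp.square(g)  # EMA of squared gradients
    if bias_correction_v:
        # Current optax behaviour (as I read it): bias-correct nu before taking the running max.
        nu_hat = nu / (1.0 - b2 ** count)
    else:
        # Proposed bias_correction_v=False: use the raw EMA, as in the original AMSGrad paper.
        nu_hat = nu
    nu_max = jnp.maximum(nu_max, nu_hat)
    return nu, nu_max  # the preconditioner is then 1 / (sqrt(nu_max) + eps)
```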
Wouldn't setting bias_correction_v=False skip applying the b2 bias correction entirely now?
Is the point to not use the bias correction at all, or to change the order of the operations as discussed here: https://github.com/pytorch/pytorch/issues/142323
PyTorch still applies the bias correction in amsgrad to this day: https://github.com/pytorch/pytorch/blob/2164b661219ab0a76aa018e955ba3d8e8f99c083/torch/optim/adam.py#L509
But TensorFlow does not (I think): https://github.com/keras-team/keras/blob/f6c4ac55692c132cd16211f4877fac6dbeead749/keras/src/optimizers/adam.py#L130-L150
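To make the ordering concrete, this is how I read the linked PyTorch line, translated to a few lines for a single array (illustrative only; `v` / `v_max` are just local names here): the max is taken over the uncorrected second moment and the bias correction is applied afterwards, inside the denominator.

```python
import jax.numpy as jnp

def pytorch_amsgrad_denom(g, v, v_max, b2, step, eps=1e-8):
    """My reading of the linked torch/optim/adam.py line, for one array (illustrative)."""
    v = b2 * v + (1.0 - b2) * jnp.square(g)   # uncorrected EMA of squared gradients
    v_max = jnp.maximum(v_max, v)             # max over the *uncorrected* second moment
    bias_correction2 = 1.0 - b2 ** step
    denom = jnp.sqrt(v_max) / jnp.sqrt(bias_correction2) + eps  # correction applied after the max
    return v, v_max, denom
```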
I think optax's original implementation is the one that makes the most sense (doing the bias correction after taking the max does not make sense to me). However, one could also simply remove the bias correction; see the plots in https://github.com/google-deepmind/optax/issues/1389. It seems to potentially improve results and, most importantly, it aligns with the paper.
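To make the ordering point concrete, here is a tiny numerical sketch (plain Python, illustrative only) of how "correct then max" (how I read the current optax code) and "max then correct" (the PyTorch ordering) diverge after one large early gradient:

```python
import math

b2, eps = 0.999, 1e-8
grads = [10.0, 0.1, 0.1]  # one large early gradient, then small ones

nu, max_corrected, max_raw = 0.0, 0.0, 0.0
for step, g in enumerate(grads, start=1):
    nu = b2 * nu + (1 - b2) * g ** 2
    correction = 1 - b2 ** step
    max_corrected = max(max_corrected, nu / correction)  # correct, then max
    max_raw = max(max_raw, nu)                           # max, then correct
    denom_correct_then_max = math.sqrt(max_corrected) + eps
    denom_max_then_correct = math.sqrt(max_raw) / math.sqrt(correction) + eps
    print(step, denom_correct_then_max, denom_max_then_correct)
```

With correct-then-max the denominator stays non-decreasing (here it stays at ~10), whereas with max-then-correct it keeps shrinking (~10, ~7.1, ~5.8) even though the running max itself never grows, because the correction factor keeps growing after the max was taken. That effective step size creeping back up is roughly why the "bias correction after the max" ordering looks odd to me.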