AMSGrad implementation differs from PyTorch/TensorFlow
Hello,
We noticed that models trained using the AMSGrad optimizer in Optax tend to yield slightly poorer results compared to the same models trained using PyTorch. The two implementations differ as follows:
Current Optax implementation:
PyTorch implementation (the implementation in TensorFlow follows the same algorithm):
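To make the difference concrete, here is a minimal scalar sketch of the two orderings as described in this thread (function names and hyperparameter defaults are illustrative, not the actual Optax or PyTorch code): Optax bias-corrects the second-moment EMA and then takes the running max of the corrected values, while PyTorch/TensorFlow take the max of the raw EMAs and bias-correct afterwards.

```python
import math

def amsgrad_step_optax(g, m, v, v_hat, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Sketch of the Optax ordering: bias-correct the second moment first,
    then take the running max of the *bias-corrected* values."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = max(v_hat, v / (1 - b2 ** t))
    return -lr * m_hat / (math.sqrt(v_hat) + eps), m, v, v_hat

def amsgrad_step_torch(g, m, v, v_hat, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Sketch of the PyTorch/TensorFlow ordering: take the running max of
    the *raw* second moments, then bias-correct the max."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = max(v_hat, v)
    return -lr * m_hat / (math.sqrt(v_hat / (1 - b2 ** t)) + eps), m, v, v_hat

# The two variants agree on the first step but diverge as soon as the
# gradient magnitude drops: the Optax-style max retains the inflated
# early bias-corrected values, shrinking all later updates.
s1 = (0.0, 0.0, 0.0)
s2 = (0.0, 0.0, 0.0)
for t, g in enumerate([1.0, 0.1, 0.1], start=1):
    u1, *s1 = amsgrad_step_optax(g, *s1, t)
    u2, *s2 = amsgrad_step_torch(g, *s2, t)
    print(t, u1, u2)
```

At step 1 both denominators reduce to `sqrt(v / (1 - b2))`, so the updates coincide; from step 2 on they differ whenever the bias-corrected second moment decreases.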
Would it be possible to align the Optax implementation with the PyTorch/TensorFlow version? This would improve consistency across the different ML frameworks and possibly improve performance.
Hello @sklenard ,
Thank you for the issue, this is very well spotted. Clearly we should try to align the implementations. Optax's version in fact makes a bit more sense to me (the bias correction should come just after the EMA computation). We can change that, but if you have code ready, could you test whether removing the bias correction (in both implementations) improves results?
Thanks again!
Hello @vroulet ,
Indeed, removing the bias correction on the 2nd moment in both versions aligns the two implementations, and it actually corresponds to the original AMSGrad paper. I have tested it on a toy model (an MLP implemented in Flax, trained to learn a polynomial function) and it seems to improve training:
The PyTorch reference (dots) corresponds to the equivalent model implemented and trained with PyTorch 2.6.0, using the same initial weights as the Flax model. The curves show the results for the different implementations of AMSGrad in Optax. If you want to check the code, I can share it of course!
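For reference, with the second-moment bias correction removed, both orderings collapse to a single update, matching the original AMSGrad paper. A minimal scalar sketch (illustrative names; the first-moment correction is kept here purely as an assumption, so that only the second-moment handling changes):

```python
import math

def amsgrad_step_no_bc(g, m, v, v_hat, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """AMSGrad without bias correction on the second moment: the max is
    taken over the raw EMAs and used directly in the denominator, so the
    before-max vs. after-max question disappears."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)  # first-moment correction kept (assumption)
    v_hat = max(v_hat, v)
    return -lr * m_hat / (math.sqrt(v_hat) + eps), m, v, v_hat
```

Note that without the correction, `v` starts near zero, so early denominators are small and early steps are comparatively large; that early behavior is one plausible source of the differences seen in the training curves.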
That's very neat, thank you! So I'd propose to:
- add a note in the docs explaining the difference with PyTorch
- add an option to remove bias correction (since it may be the most reasonable thing to do)
What do you think? Also, do you agree that a priori the bias correction should come just after the EMA computation?
Yes, adding an option to remove the bias correction is a good idea. I am wondering whether it would also make sense to include an option that reproduces the behavior of PyTorch and TensorFlow (at least for benchmarking purposes). I agree that, mathematically, the bias correction should come after the EMA, and I don't really understand why this yields worse results.
Incidentally, I discovered that the AMSGrad algorithm, with bias correction after the EMA, is described in the documentation of PyTorch <= 2.5 (see PyTorch Documentation). However, this is a mistake in the documentation: the code actually applies the bias correction after the max, as pointed out in PyTorch Issue #142323.
Hi! I've opened a PR addressing this issue. It adds an option to disable bias correction on the second moment.