hfta
[Optim] Upgrade adam to 1.9
- [x] Tested
- [x] Formatted with YAPF
PyTorch 1.9 uses a functional API to perform the Adam computation. However, it seems like we don't have a corresponding functional API in HFTA right now, so rather than creating one, I modified the original code directly.
Note that although PyTorch 1.9 added a doc note saying it uses a different L2 regularization (see this), I didn't find any actual difference in the code...
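For context, this is roughly how stock PyTorch 1.9 structures its (unfused) Adam: `Adam.step()` gathers the per-parameter tensors into flat lists and forwards them to `torch.optim._functional.adam`. Below is a minimal sketch, assuming the 1.9 signature of that private module (it may change between releases) and assuming the optimizer state has already been initialized:

```python
from torch.optim import _functional as F  # private module; signature may change between releases


def adam_step_sketch(group, state):
    """Roughly what torch.optim.Adam.step() does for one param group in 1.9."""
    params_with_grad, grads = [], []
    exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps = [], [], [], []
    beta1, beta2 = group['betas']

    for p in group['params']:
        if p.grad is None:
            continue
        params_with_grad.append(p)
        grads.append(p.grad)
        st = state[p]                      # assumed already initialized
        st['step'] += 1
        exp_avgs.append(st['exp_avg'])
        exp_avg_sqs.append(st['exp_avg_sq'])
        if group['amsgrad']:
            max_exp_avg_sqs.append(st['max_exp_avg_sq'])
        state_steps.append(st['step'])     # plain Python ints in 1.9

    # All of the actual math lives in the functional entry point.
    F.adam(params_with_grad, grads, exp_avgs, exp_avg_sqs,
           max_exp_avg_sqs, state_steps,
           amsgrad=group['amsgrad'], beta1=beta1, beta2=beta2,
           lr=group['lr'], weight_decay=group['weight_decay'],
           eps=group['eps'])
```

HFTA's fused optimizers could eventually mirror this split, but as noted above this PR keeps the computation inline.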
Please review this change a bit more carefully.
Thanks
Actually, the plan for the optimizers in 1.9 is to use the torch._foreach_xxx_ ops for performance reasons. You can think of it as combining intra-model and inter-model horizontal fusion. Some context (a sketch of a foreach-based update follows the links):
https://github.com/pytorch/pytorch/issues/38655
https://github.com/pytorch/pytorch/tree/master/torch/optim/_multi_tensor
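To make the idea concrete, here is a minimal sketch of an Adam step written with the `torch._foreach_*` ops (ignoring weight decay and AMSGrad, and assuming a single shared step count; this is a simplification for illustration, not the actual `torch/optim/_multi_tensor` implementation):

```python
import math
from typing import List

import torch
from torch import Tensor


def foreach_adam_step(params: List[Tensor], grads: List[Tensor],
                      exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor],
                      step: int, lr: float = 1e-3,
                      beta1: float = 0.9, beta2: float = 0.999,
                      eps: float = 1e-8) -> None:
    """One Adam update applied to a whole list of tensors via multi-tensor ops."""
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    torch._foreach_mul_(exp_avgs, beta1)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)

    # v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    torch._foreach_mul_(exp_avg_sqs, beta2)
    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)

    # denom = sqrt(v_t) / sqrt(bias_correction2) + eps
    denom = torch._foreach_sqrt(exp_avg_sqs)
    torch._foreach_div_(denom, math.sqrt(bias_correction2))
    torch._foreach_add_(denom, eps)

    # p_t = p_{t-1} - (lr / bias_correction1) * m_t / denom
    torch._foreach_addcdiv_(params, exp_avgs, denom, -lr / bias_correction1)
```

Each `_foreach_*` call launches one fused kernel over the whole list instead of one kernel per tensor, which is where the horizontal fusion across parameters (and, for HFTA, potentially across models) would come from.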
I'm wondering if you would be able to take the initiative on this?
I can try to use the foreach API, following this.
Just to confirm, is this feature stable and released in 1.9? Or is it still under development?
I believe it's still under development, but looking at https://github.com/pytorch/pytorch/tree/master/torch/optim/_multi_tensor it seems like it's pretty close to release.