Adds Adan Optimizer
Closes #401. Implementation based on the official code.
Thank you very much! If the paper's claims stack up, this will be very useful to the JAX community.
Btw, there is PyTorch reference code from the authors of the paper: https://github.com/sail-sg/Adan
Would you mind loading both the PyTorch and the optax implementations in a colab and showing that they match? Applying 5-10 steps with some dummy gradients as inputs might highlight subtle differences that would be hard to spot from just staring at the code.
Hey @mtthss, always happy to help :D Yeah, I based my code on the official implementation and here's the colab. At the moment I have two questions:
- The current implementation seems to generate slightly different results: for example, with a 256x256 matrix and 1000 updates, the norm of the difference between the PyTorch and JAX versions is 2.4e-5 on CPU. Is that expected, or should I keep looking for divergences in the implementation?
- The default for the official implementation of Adan seems to use a different type of weight decay, where the params are divided by (1 + lr * wd) instead of multiplied by (1 - lr * wd). This causes a large difference in behavior, but I'm not exactly sure what is meant by the `no_prox` parameter that controls which kind of weight decay is used.
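Concretely, the two behaviours I'm seeing correspond to something like the following sketch (this is not the reference code; `params`, `updates`, `lr`, and `wd` are placeholders):

```python
import jax

def decay_by_multiplication(params, updates, lr, wd):
    # Decoupled decay: params are effectively multiplied by (1 - lr * wd).
    return jax.tree_util.tree_map(
        lambda p, u: p + u - lr * wd * p, params, updates)

def decay_by_division(params, updates, lr, wd):
    # Proximal-style decay: the updated params are divided by (1 + lr * wd).
    return jax.tree_util.tree_map(
        lambda p, u: (p + u) / (1.0 + lr * wd), params, updates)
```

Here `updates` stands for the already lr-scaled Adan step, i.e. what would otherwise be added straight to the params.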
I've computed the relative error and it's on the order of 10^-8 (though it still grows as we do more updates). Thoughts? @mtthss
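In case it's useful, the check above amounts to something like this sketch (assuming the PyTorch result has already been converted to an array):

```python
import jax.numpy as jnp

def relative_error(reference, candidate):
    # Norm of the difference, normalised by the norm of the reference result.
    reference = jnp.asarray(reference)
    candidate = jnp.asarray(candidate)
    return jnp.linalg.norm(reference - candidate) / jnp.linalg.norm(reference)
```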
Hi there, I have been looking at this optimizer as well, and thought I'd chime in!
Firstly, thanks for the work that you have done. This looks to be a credible and nicely-written implementation.
Some notes:
1. Line 6 of the Adan algorithm says $\eta_k = \eta / \sqrt{\mathbf{n}_k + \epsilon}$; however, in the released code, the $\epsilon$ is not within the square root. This is a discrepancy between the paper and the published code, and it is an open question which we should follow here. It was the same in v1 and v2 of the paper. I have submitted an issue; we can see what they say. We could add an `eps_root` parameter to enable the user to set it how they like.
2. I would change the `weight_decay` default to 0.02, as in the paper.
3. Why do you have a `mu_dtype` but not a `delta_dtype`? Or, since they are both first-order accumulators, you may want to reuse the datatype.
4. You mention that there are two different implementations of the weight decay in the released code. I think that the `no_prox==True` condition is the one you have implemented, which fits nicely into an `optax.chain` in your code. However, in the colab you wrote, weight decay is set to `0.`, which means that there is no difference between the conditions. The implementation which matches the paper is the `no_prox==False` condition, where the learning rate is harder to factor out into an optax chain.
To deal with this final issue, we could either:
- have a `no_prox` condition that copies their implementation,
- implement it as written in the paper (different to yours: `_scale_by_learning_rate` and `transform.add_decayed_weights` will not be chained; the update will happen in one step), or
- just leave it as it is.
I did some testing myself and it looks like your implementation only really diverges for non-zero weight decay and `no_prox==False`, as expected.
Hi @Zach-ER! Thanks for the thorough response, for checking with the authors and running my code!
- From the authors' response to your issue, it looks like their results were achieved with `eps` outside the square root, so I think we can leave it as is, wdyt?
- Makes sense, I believe I'll reuse it.
- Yeah, I'm not sure about the best way to proceed here; maybe I can implement the `no_prox` condition in the transform and create two aliases, `adan` and `adan_no_prox`, wdyt?
Once again thanks for the comments :D
- I would prefer leaving the defaults as they are but also having an `eps_root` argument, defaulting to 0.0. This has the benefit of being a closer match to the `adam` signature and also lets people implement it as in the paper (if they so choose); see the sketch after this list.
- 👍🏻, just make sure the name is appropriate.
- I think we should definitely have a method that matches their default implementation. I would put a condition into the `adan` optimizer that matches the reference behaviour. This will be slightly fiddly. Do you want to have a try at this? If not, I would be happy to draft something.
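To make the first point concrete, here is a rough sketch of the step-size computation with both knobs (the names and defaults are assumptions, following the `eps`/`eps_root` convention of `scale_by_adam`):

```python
import jax.numpy as jnp

def adan_step_size(nu, learning_rate, eps=1e-8, eps_root=0.0):
    # nu plays the role of n_k in the paper. With these defaults, eps sits
    # outside the square root (as in the released Adan code); setting
    # eps_root > 0 and eps = 0 recovers the paper's eta / sqrt(n_k + eps).
    return learning_rate / (jnp.sqrt(nu + eps_root) + eps)
```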
Minor nitpick: the docstring needs fixing for `b1` and `b2` (and `b3` needs adding). Something like:
b1: Decay rate for the exponentially weighted average of gradients.
b2: Decay rate for the exponentially weighted average of the difference of
  gradients.
b3: Decay rate for the exponentially weighted average of the squared term.
Hi there,
I have posted a new issue in the original repo here. If they say their experiments were conducted with `no_prox=True` (the condition you implemented), then I think we can ignore the other condition: your PR reproduces their algorithm and fits well style-wise with the rest of the codebase.
Will update this when the authors respond.
OK, the authors have responded.
Their experiments use `no_prox=False`, the condition that you have not implemented, so I think we do need to implement that one and match their algorithm exactly.
What I would do:
- add a `use_proximal_operator` boolean argument, defaulting to `True`, to match what's in the paper (this is apparently what `prox` is short for);
- if `False`, implement exactly as you've already done;
- if `True`, use a slightly less standard implementation (roughly sketched below).
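Very roughly, and only as a back-of-the-envelope sketch rather than the full version linked in the next paragraph (the helper name and the meaning of `direction` are assumptions), the final step could branch like this, returning the deltas that `optax.apply_updates` expects:

```python
import jax

def adan_param_update(params, direction, lr, wd, use_proximal_operator=True):
    # `direction` stands for the combined Adan step before learning rate
    # and weight decay are applied; the return value is new_params - params.
    if use_proximal_operator:
        # Paper / no_prox=False: the decay divides the freshly updated params,
        # so the learning rate cannot be factored out into a trailing scale step.
        def leaf(p, d):
            return (p - lr * d) / (1.0 + lr * wd) - p
    else:
        # no_prox=True: plain decoupled decay, which chains as in the current PR.
        def leaf(p, d):
            return -lr * d - lr * wd * p
    return jax.tree_util.tree_map(leaf, params, direction)
```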
I wrote a `jax` version that matches their implementation, based on your gist and codebase. It is here and matches their results completely with weight decay turned on.
If you could integrate what I've written into your PR, that would be great — if not, I will find some time to do it (but quite busy at the moment).
Again, thanks for your work — this will be a great addition to the library 🎖
Thank you very much for all the help! Right now I'm at RIIAA Ecuador, but I'll try to integrate your code as soon as I'm back home, around Sunday or Monday :)
So it does pass the alias tests, but it looks like Sphinx is erroring now. Any ideas for a quick fix (I'm not experienced with Sphinx)? @Zach-ER
@joaogui1 it should be fixed now, could you update the PR?
@hbq1 update how? Should I merge main?
> Should I merge main?
Yes 👍
@Zach-ER thanks a lot for your great comments! Since you are an experienced user of this optimiser, I was wondering if the current version looks good to you? :)
Yes, LGTM. Looking forward to trying it out some more 🙌🏻
Done @hbq1
Yes, I think that this could be a problem.
I agree that it is in a similar boat to the lookahead optimizer.
Any update on this?
@carlosgmartin no updates, this PR is currently orphaned. Do you want to take over?
@fabianp What changes, if any, need to be made to @joaogui1's PR?
And what's the consensus on @mkunesch's questions?
- A first step would be to update with main; currently there are some conflicts with the current head.
- I believe the issues highlighted by @mkunesch only apply to the proximal version, `scale_by_proximal_adan`. I would suggest focusing first on `scale_by_adan`, which shouldn't have those issues if I understood correctly.
@fabianp Created a PR: #1090.
excellent, closing this PR then