
Adds Adan Optimizer

joaogui1 opened this pull request 2 years ago · 21 comments

Closes #401. Implementation based on the official code

joaogui1 avatar Sep 04 '22 12:09 joaogui1

Thank you very much! If the paper's claims stack up, this will be very useful to the JAX community.

Btw, there is PyTorch reference code from the authors of the paper: https://github.com/sail-sg/Adan

Would you mind loading both the PyTorch and the optax implementations in a colab and showing that they match? Applying 5-10 steps with some dummy gradients as inputs might highlight subtle differences that would be hard to spot from just staring at the code.
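
Something roughly like this is what I have in mind (a sketch only: the optimizer below is just a placeholder for the Adan transform from this PR, and the PyTorch side is left as comments since it runs in a separate framework):

    import jax
    import jax.numpy as jnp
    import optax

    def run_optax(opt, params, grads_per_step):
        # Apply a fixed sequence of dummy gradients with a given optax transform.
        state = opt.init(params)
        for g in grads_per_step:
            updates, state = opt.update(g, state, params)
            params = optax.apply_updates(params, updates)
        return params

    key = jax.random.PRNGKey(0)
    params = jnp.zeros((256, 256))
    grads = [jax.random.normal(jax.random.fold_in(key, i), (256, 256)) for i in range(10)]

    # Swap optax.adam for the Adan transform from this PR.
    optax_params = run_optax(optax.adam(1e-3), params, grads)
    # torch_params = ...  # produced by the sail-sg reference on the same gradients
    # rel_err = jnp.linalg.norm(optax_params - torch_params) / jnp.linalg.norm(torch_params)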

mtthss avatar Sep 05 '22 14:09 mtthss

Hey @mtthss, always happy to help :D Yeah, I based my code on the official implementation, and here's the colab. At the moment I have two questions:

  1. The current implementation seems to generate slightly different results. For example, when using a 256x256 matrix and doing 1000 updates, the norm of the difference between the pytorch and jax versions is 2.4e-5 on the CPU. Is that expected, or should I keep looking for more divergences in the implementation?
  2. The default for the official implementation of Adan seems to use a different type of weight decay, where the params are divided by (1 + lr * wd) instead of multiplied by (1 - lr * wd); see the sketch below. This causes a large difference in behavior, but I'm not exactly sure what is meant by the no_prox parameter that controls which kind of weight decay is used.
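
For reference, the two variants look roughly like this (plain jnp, values illustrative; update stands in for the scaled Adan step):

    import jax.numpy as jnp

    lr, wd = 1e-3, 0.02
    params = jnp.ones((4,))
    update = jnp.full((4,), 0.1)  # stand-in for the Adan update on this step

    # Decoupled (AdamW-style) decay, no_prox=True in the reference code:
    params_decoupled = params * (1 - lr * wd) - lr * update

    # "Proximal" decay, no_prox=False (the reference default): the decay is
    # folded into a division, so it cannot be applied as a separate step
    # after scaling by the learning rate.
    params_proximal = (params - lr * update) / (1 + lr * wd)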

joaogui1 avatar Sep 05 '22 20:09 joaogui1

I've computed the relative error and it's on the order of 10^-8 (though it still grows as we do more updates), thoughts? @mtthss

joaogui1 avatar Sep 12 '22 14:09 joaogui1

Hi there, I have been looking at this optimizer as well, and thought I'd chime in!

Firstly, thanks for the work that you have done. This looks to be a credible and nicely-written implementation.

Some notes:

  1. L6 of the Adan algorithm says $\mathbf{\eta}_k = \eta / \sqrt{\mathbf{n}_k + \epsilon}$; however, in the released code, the $\epsilon$ is not within the square root. This is a discrepancy between the paper and the published code, and it is an open question which we should follow here (it was the same in v1 and v2 of the paper). I have submitted an issue; we can see what the authors say. We could add an eps_root parameter to enable the user to set it how they like. Separately, I would change the weight_decay default to 0.02, as in the paper.
  2. Why do you have a mu_dtype but not a delta_dtype? Since they are both first-order accumulators, you may want to reuse the same dtype for both.
  3. You mention that there are two different implementations of the weight decay in the released code. I think the no_prox==True condition is the one you have implemented, which fits nicely into an optax.chain in your code. However, in the colab you wrote, weight decay is set to 0., which means there is no difference between the conditions. The implementation that matches the paper is the no_prox==False condition, where the learning rate is harder to factor out into an optax chain.

To deal with this final issue, we could either have:

  • add a no_prox condition that copies their implementation, or
  • implement it as written in the paper (different from yours: _scale_by_learning_rate and transform.add_decayed_weights would not be chained, and the update would happen in one step, as sketched below), or just leave it as it is.
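
For the decoupled case, the chaining I have in mind looks roughly like this (scale_by_adan is just a placeholder name for the transform in this PR; the defaults are illustrative):

    import optax

    def adan(learning_rate, b1=0.98, b2=0.92, b3=0.99, eps=1e-8, weight_decay=0.02):
        # Decoupled variant (no_prox=True): weight decay and the learning rate
        # compose as separate chained transforms.
        return optax.chain(
            scale_by_adan(b1=b1, b2=b2, b3=b3, eps=eps),  # placeholder transform
            optax.add_decayed_weights(weight_decay),
            optax.scale(-learning_rate),
        )

    # The proximal variant (no_prox=False) divides the updated parameters by
    # (1 + lr * wd), so it needs the learning rate and the parameters inside a
    # single transform rather than as a final chained link.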

I did some testing myself and it looks like your implementation only really diverges for non-zero weight decay and no_prox==False, as expected.

Zach-ER avatar Sep 14 '22 15:09 Zach-ER

Hi @Zach-ER! Thanks for the thorough response, for checking with the authors and running my code!

  1. From the authors' response to your issue it looks like their results were achieved with eps outside the square root, so I think we can leave it as is, wdyt?
  2. Makes sense, I believe I'll reuse it.
  3. Yeah, I'm not sure about the best way to proceed here. Maybe I can implement the no_prox condition in the transform and create two aliases, adan and adan_no_prox, wdyt?

Once again thanks for the comments :D

joaogui1 avatar Sep 16 '22 11:09 joaogui1

  1. I would prefer leaving the defaults as they are but also having an eps_root argument, defaulting to 0.0 (see the sketch below). This has the benefit of being a closer match to the adam signature and also letting people implement it as in the paper (if they so choose).
  2. 👍🏻 — just make sure the name is appropriate.
  3. I think we should definitely have a method that matches their default implementation. I would put a condition into the adan optimizer that matches the reference behaviour. This will be slightly fiddly. Do you want to have a try at this? If not, I would be happy to draft something.
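
Concretely, I'm imagining the scaling step mirroring the adam convention, something like this (variable names and values are illustrative):

    import jax.numpy as jnp

    n_k = jnp.array([1e-4, 2e-3])  # Adan second-moment accumulator (illustrative values)
    g = jnp.array([0.1, -0.2])     # combined update term
    eps, eps_root = 1e-8, 0.0

    scaled = g / (jnp.sqrt(n_k + eps_root) + eps)
    # eps_root=0.0 with eps>0 matches the released Adan code (eps outside the sqrt);
    # eps=0.0 with eps_root>0 matches the paper's eta / sqrt(n_k + eps).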

Minor nitpick: the docstring needs fixing for b1, b2 (and b3 needs adding). Something like:

    b1: Decay rate for the exponentially weighted average of gradients.
    b2: Decay rate for the exponentially weighted average of difference of
      gradients.
    b3: Decay rate for the exponentially weighted average of the squared term.

Zach-ER avatar Sep 21 '22 08:09 Zach-ER

Hi there, I have posted a new issue in the original repo here. If they say their experiments were conducted with no_prox=True, then I think we can ignore the other condition, and your PR reproduces their algorithm and fits well style-wise with the rest of the codebase.

Will update this when the authors respond.

Zach-ER avatar Sep 29 '22 08:09 Zach-ER

OK, the authors have responded. Their experiments use no_prox=False, the condition that you have not implemented, so I think we do need to implement that one and match their algorithm exactly.

What I would do:

  1. Add a use_proximal_operator boolean argument, defaulting to True, to match what's in the paper (this is apparently what prox is short for).
  2. If False, implement exactly as you've already done.
  3. If True, we need a slightly less standard implementation (sketched below).
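
For the True branch, the fiddly bit is that an optax update has to be expressed as an additive delta, so the division by (1 + lr * wd) ends up looking roughly like this inside the transform (a sketch only, not the actual code I link below):

    import jax.numpy as jnp

    def proximal_update(params, adan_update, lr, wd):
        # Proximal (no_prox=False) step: new_params = (params - lr * u) / (1 + lr * wd),
        # returned as an additive delta so it can be applied with optax.apply_updates.
        new_params = (params - lr * adan_update) / (1 + lr * wd)
        return new_params - params

    params = jnp.ones((4,))
    u = jnp.full((4,), 0.1)
    delta = proximal_update(params, u, lr=1e-3, wd=0.02)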

I wrote a jax version that matches their implementation, based on your gist and codebase. It is here and matches their results completely with weight decay turned on.

If you could integrate what I've written into your PR, that would be great — if not, I will find some time to do it (but quite busy at the moment).

Again, thanks for your work — this will be a great addition to the library 🎖

Zach-ER avatar Sep 30 '22 10:09 Zach-ER

Thank you very much for all the help! Right now I'm at RIIAA Ecuador, but I'll try to integrate your code as soon as I'm back home, around Sunday or Monday :)

joaogui1 avatar Sep 30 '22 19:09 joaogui1

So it does pass the alias tests, but it looks like Sphinx is erroring now. Any ideas for a quick fix (I'm not experienced with Sphinx)? @Zach-ER

joaogui1 avatar Oct 14 '22 12:10 joaogui1

@joaogui1 it should be fixed now, could you update the PR?

hbq1 avatar Oct 20 '22 13:10 hbq1

@hbq1 update how? Should I merge main?

joaogui1 avatar Oct 20 '22 14:10 joaogui1

Should I merge main?

Yes 👍

hbq1 avatar Oct 20 '22 14:10 hbq1

@Zach-ER thanks a lot for your great comments! Since you are an experienced user of this optimiser, I was wondering if the current version looks good to you? :)

hbq1 avatar Oct 20 '22 19:10 hbq1

Yes, LGTM. Looking forward to trying it out some more 🙌🏻

Zach-ER avatar Oct 21 '22 13:10 Zach-ER

Done @hbq1

joaogui1 avatar Oct 21 '22 21:10 joaogui1

Yes, I think that this could be a problem.

I agree that it is in a similar boat to the lookahead optimizer.

Zach-ER avatar Oct 25 '22 10:10 Zach-ER

Any update on this?

carlosgmartin avatar Mar 21 '24 06:03 carlosgmartin

@carlosgmartin no updates, this PR is currently orphaned. Do you want to take over?

fabianp avatar Mar 21 '24 08:03 fabianp

@fabianp What changes, if any, need to be made to @joaogui1's PR?

And what's the consensus on @mkunesch's questions?

carlosgmartin avatar Mar 27 '24 14:03 carlosgmartin

  1. A first step would be to update the branch with main; currently there are some conflicts with the current head.
  2. I believe the issues highlighted by @mkunesch only apply to the proximal version scale_by_proximal_adan. I would suggest focusing first on scale_by_adan, which shouldn't have those issues if I understood correctly.

fabianp avatar Mar 27 '24 15:03 fabianp

@fabianp Created a PR: #1090.

carlosgmartin avatar Oct 04 '24 01:10 carlosgmartin

excellent, closing this PR then

fabianp avatar Oct 04 '24 07:10 fabianp