Add implementation of Lion optimizer (EvoLved Sign Momentum)
Summary
I'd like to contribute an Optax implementation of the Lion optimizer: a gradient transformation plus a convenience lion(...) wrapper in contrib that composes decoupled weight decay and learning-rate scaling. Lion tracks a single momentum buffer and uses the sign of an interpolation between the gradient and that momentum as the update direction, as described in the paper https://arxiv.org/abs/2302.06675
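For reference, a minimal sketch of what this could look like (function names like scale_by_lion_sketch are illustrative rather than a proposed public API, and the final version would follow Optax conventions such as NamedTuple states):

```python
import jax
import jax.numpy as jnp
import optax


def scale_by_lion_sketch(b1: float = 0.9, b2: float = 0.99):
  """Sketch of the Lion transform: sign of an interpolated momentum."""

  def init_fn(params):
    # Single momentum buffer, initialized to zeros.
    return {'mu': jax.tree_util.tree_map(jnp.zeros_like, params)}

  def update_fn(updates, state, params=None):
    del params
    mu = state['mu']
    # Update direction: sign of an interpolation of gradient and momentum.
    new_updates = jax.tree_util.tree_map(
        lambda g, m: jnp.sign(b1 * m + (1.0 - b1) * g), updates, mu)
    # The momentum itself is tracked with a second coefficient b2.
    new_mu = jax.tree_util.tree_map(
        lambda g, m: b2 * m + (1.0 - b2) * g, updates, mu)
    return new_updates, {'mu': new_mu}

  return optax.GradientTransformation(init_fn, update_fn)


def lion_sketch(learning_rate: float, b1: float = 0.9, b2: float = 0.99,
                weight_decay: float = 1e-3):
  # Convenience wrapper composing decoupled weight decay and LR scaling.
  return optax.chain(
      scale_by_lion_sketch(b1, b2),
      optax.add_decayed_weights(weight_decay),
      optax.scale(-learning_rate),
  )
```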
What I will include:
- Implementation file at optax/contrib/_lion.py
- Test file at optax/contrib/_lion_test.py
- A short note about fp16 behaviour and recommended dtype handling (a rough sketch below)
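On the dtype point, one option I have in mind (a rough sketch, assuming we expose a mu_dtype argument the way optax.scale_by_adam does; the helper name and default here are my own, not an existing API):

```python
import jax
import jax.numpy as jnp


def init_momentum(params, mu_dtype=None):
  # Hypothetical helper: keep the momentum buffer in an explicit dtype
  # (e.g. jnp.float32) even when params are float16, so the small
  # (1 - b2) * g contributions do not underflow; mu_dtype=None keeps
  # the params' own dtype.
  return jax.tree_util.tree_map(
      lambda p: jnp.zeros_like(p, dtype=mu_dtype), params)
```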
Request
- Guidance on whether maintainers would be open to this kind of contribution placed under optax/contrib
- Any specific tests, coding-style requirements, or helper utilities I should follow
- I can open a PR with the implementation and tests
Thanks, I'm happy to iterate quickly based on feedback.
I have raised a PR regarding this. @vroulet, could you please take a look?
Linking the PR for visibility: https://github.com/google-deepmind/optax/pull/1438
Thanks a lot, @rdyro, for doing that.
Did you check this https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.lion ?
Hey, thanks a lot for letting me know.
I would like to re-implement Lion with a smooth sign. The existing implementation uses a hard sign function (jnp.sign()), which has limitations:
- poor gradient flow
- discontinuities in the update, which can cause training instability
@vroulet, let me know if I can go ahead with this implementation instead. And thanks a lot for pointing me to the existing Lion implementation.
Can you point to a resource that proposes using a smooth sign in Lion?
Sure @rdyro: https://www.researchgate.net/publication/385679808_RLion_A_Refined_Lion_Optimizer_for_Deep_Learning
This is the paper that proposes the smooth sign: it replaces the discrete sign(·) update of the Lion optimizer with a continuous, bounded function (arctan) to smooth out the fluctuations.
Please take a look and let me know how you would like me to proceed.
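For concreteness, this is roughly the kind of replacement I have in mind (an illustrative sketch only; the exact functional form and scaling should follow the paper and your feedback, and the temperature parameter name and default are my own):

```python
import jax.numpy as jnp


def smooth_sign(x, temperature: float = 1.0):
  # Continuous, bounded surrogate for jnp.sign: (2 / pi) * arctan(x / t)
  # stays in (-1, 1) and approaches sign(x) as temperature goes to 0.
  return (2.0 / jnp.pi) * jnp.arctan(x / temperature)


# In the Lion update, jnp.sign(b1 * m + (1 - b1) * g) would become
# smooth_sign(b1 * m + (1 - b1) * g, temperature).
```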