jaxopt

Complex gradients

Open GeoffNN opened this issue 4 years ago • 4 comments

Opening this as a nota bene. When optimizing a real-valued loss over complex parameters, the gradient returned by JAX must be conjugated before it is used in the update. Currently, all jaxopt optimizers would be incorrect on complex parameters because of this.
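
To make the point concrete, here is a minimal sketch in plain JAX (not jaxopt code). The quadratic `loss`, the `target` value, and the `descend` helper are all hypothetical names chosen for illustration. It runs the same gradient descent loop twice, once with the conjugated gradient and once without:

```python
import jax
import jax.numpy as jnp

# Hypothetical target value, chosen only for this illustration.
target = jnp.array(1.0 + 2.0j)

def loss(z):
    # Real-valued loss |z - target|^2 of a complex parameter z.
    d = z - target
    return jnp.real(d * jnp.conj(d))

def descend(z0, conjugate, lr=0.1, steps=50):
    # Plain gradient descent; `conjugate` toggles the fix discussed above.
    z = z0
    for _ in range(steps):
        g = jax.grad(loss)(z)
        z = z - lr * (jnp.conj(g) if conjugate else g)
    return z

z0 = jnp.array(0.0 + 0.0j)
z_good = descend(z0, conjugate=True)   # approaches `target`
z_bad = descend(z0, conjugate=False)   # imaginary part moves away from `target`
```

With the conjugate, both the real and imaginary parts contract toward `target`. Without it, the imaginary part is pushed in the ascent direction, so the iterate drifts away without raising any error.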

Moreover, any optimizer that relies on second-order moments (e.g. Adam) must use the squared complex modulus instead of the plain square. Current jaxopt solvers might be affected as well. I'm unsure what implicit diff would do with complex parameters, but perhaps we could emit a warning that it is probably incorrect for now.
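
A small sketch of the second-moment issue, again in plain JAX. `second_moment_update` is a hypothetical helper written for this comment, not an Optax or jaxopt function, and the decay value `b2` is just the usual Adam default:

```python
import jax.numpy as jnp

g = jnp.array([3.0 + 4.0j, 1.0j])  # example complex gradient

# Plain square: stays complex and can have a negative real part,
# so it is not a valid second-moment estimate.
naive_sq = g ** 2

# Squared modulus |g|^2: real and non-negative.
modulus_sq = jnp.real(g * jnp.conj(g))

def second_moment_update(v, g, b2=0.999):
    # Hypothetical Adam-style exponential moving average of |g|^2.
    return b2 * v + (1.0 - b2) * jnp.real(g * jnp.conj(g))

v = second_moment_update(jnp.zeros(2), g)
```

Because `v` stays real and non-negative, the later `sqrt(v)` scaling in an Adam-style update remains well defined.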

I realized this while using Optax on a model with complex weights and thought it would be good to incorporate this in jaxopt solvers as well, since users 1) might not be aware of this and 2) would find it really hard to debug on their side.

GeoffNN avatar Feb 03 '22 23:02 GeoffNN

For instance, tree_l2_norm would currently give incorrect results on complex parameters.

https://github.com/google/jaxopt/blob/eb6e75dfee1d25cc2b206ad1410668c576bf6750/jaxopt/_src/tree_util.py#L84
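
As a rough illustration of the failure mode (a sketch only, not the code at the link above), compare a squared norm that squares each leaf directly with one that uses the squared modulus. Both function names and the `params` tree are made up for this example:

```python
import jax
import jax.numpy as jnp

def naive_sq_norm(tree):
    # Squares each leaf directly; on complex leaves this is z**2, not |z|**2.
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(jnp.sum(jnp.square(x)) for x in leaves)

def complex_safe_sq_norm(tree):
    # Uses z * conj(z), which equals |z|**2 for real and complex leaves alike.
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(jnp.sum(jnp.real(x * jnp.conj(x))) for x in leaves)

# Hypothetical parameter tree mixing complex and real leaves.
params = {"w": jnp.array([1.0 + 1.0j, 2.0j]), "b": jnp.array([3.0])}

bad = naive_sq_norm(params)          # complex-valued, not a norm
good = complex_safe_sq_norm(params)  # real and non-negative
```

On purely real trees the two functions agree, so a change along these lines should not alter existing behavior for real parameters.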

GeoffNN avatar Feb 03 '22 23:02 GeoffNN

+1 on fixing this, thanks for catching

mblondel avatar Feb 05 '22 09:02 mblondel

I think this is an issue we should tackle soon because, as you said, it could fail silently. Do you want to tackle it?

mblondel avatar Jun 28 '22 09:06 mblondel

Hey! Sorry, I was interning this summer and off GitHub. I'll start checking this out!

GeoffNN avatar Sep 12 '22 18:09 GeoffNN