
Add missing projections from jaxopt

Open carlosgmartin opened this issue 7 months ago • 6 comments

Related: https://github.com/google-deepmind/optax/issues/977

The following projections are present in jaxopt but missing in optax:

I can work on some of these.

carlosgmartin, Apr 27 '25

I am interested in contributing to this (and in learning along the way, as I have only recently started getting comfortable with JAX in general).

Could someone outline the best way to approach this?

Thanks!

aymuos15, Jun 19 '25

@aymuos15 To see how Optax currently implements its projection functions, see the projections module.

To see how JAXopt implements the ones in the above list, you can click on one of them (e.g. projection_sparse_simplex) and then click on the button labeled [source], which will take you to the source code.
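For a sense of the shape such a port could take, here is a minimal sketch of a sparse-simplex projection written as plain JAX functions. It is not copied from the JAXopt source: it assumes the greedy approach of keeping the max_nz largest entries and projecting them onto the simplex (Kyrillidis et al., "Sparse projections onto the simplex"), with the helper projection_simplex using the standard sort-based simplex projection (as in Duchi et al., 2008). Both names and signatures here are illustrative assumptions. Optax's projection functions generally accept pytrees rather than single arrays, so a real port would also need to handle that.

```python
import jax
import jax.numpy as jnp


def projection_simplex(x: jax.Array, value: float = 1.0) -> jax.Array:
  """Euclidean projection of a 1-D array onto {p : p >= 0, sum(p) = value}."""
  u = jnp.sort(x)[::-1]                    # entries in decreasing order
  cssv = jnp.cumsum(u) - value             # partial sums minus the target mass
  k = jnp.arange(1, x.shape[0] + 1)
  support = u - cssv / k > 0               # True for the entries that stay positive
  rho = jnp.max(jnp.where(support, k, 0))  # size of the support (>= 1 if value > 0)
  theta = cssv[rho - 1] / rho              # common shift applied to every entry
  return jnp.maximum(x - theta, 0.0)


def projection_sparse_simplex(x: jax.Array, max_nz: int,
                              value: float = 1.0) -> jax.Array:
  """Projection onto the simplex with at most `max_nz` nonzero entries.

  Hypothetical helper: keeps the `max_nz` largest entries, projects them onto
  the simplex, and zeros out everything else. `max_nz` must be a static int
  because `jax.lax.top_k` needs a static k.
  """
  top_vals, top_idx = jax.lax.top_k(x, max_nz)   # greedy support selection
  proj = projection_simplex(top_vals, value)     # project the kept entries
  return jnp.zeros_like(x).at[top_idx].set(proj)  # scatter back, zeros elsewhere
```

With this sketch, projection_sparse_simplex(jnp.array([0.5, 2.0, -1.0, 1.5]), max_nz=2) keeps the entries 2.0 and 1.5, shifts both down by 1.25, and returns [0.0, 0.75, 0.0, 0.25].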

Note that I'm already adding a few at https://github.com/google-deepmind/optax/pull/1351, so you can work on some of the remaining ones, if you'd like.

carlosgmartin, Jun 21 '25

Thank you so much, @carlosgmartin.

I will get started on this in the morning, my time.

aymuos15, Jun 21 '25

Two questions.

I have started working on the transport and Birkhoff projections: https://github.com/aymuos15/optax/tree/more_projections

  1. Should I create a separate PR, or open a PR against your fork?
  2. How should I handle dependencies that exist in jaxopt but not in optax?
    • affine_set depends on EqualityConstrainedQP
    • polyhedron depends on OSQP
    • box depends on Bisection
Should I implement these dependencies myself, or is there another way forward? Thanks.
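One way to prototype the transport and Birkhoff projections without any extra solver is Dykstra's alternating projections between the affine set of matrices with prescribed row and column sums (which has a closed-form orthogonal projection) and the nonnegative orthant. This is a swapped-in technique, not necessarily what JAXopt's projection_transport or projection_birkhoff do internally, and the function names below are hypothetical. The sketch assumes a and b are nonnegative with equal sums, and it runs a fixed number of iterations, so the iterates are exactly nonnegative while the marginals are only matched approximately. Dykstra's correction terms (p, q) are what make the limit the Euclidean projection rather than just some feasible point.

```python
import jax
import jax.numpy as jnp


def _project_marginals(y, a, b):
  """Orthogonal projection onto the affine set {X : X 1 = a, X^T 1 = b}.

  Assumes sum(a) == sum(b), so the set is nonempty.
  """
  m, n = y.shape
  r = y.sum(axis=1) - a   # row-sum residuals, shape (m,)
  c = y.sum(axis=0) - b   # column-sum residuals, shape (n,)
  s = r.sum()             # equals c.sum() when sum(a) == sum(b)
  return y - r[:, None] / n - c[None, :] / m + s / (m * n)


def projection_transport_dykstra(y, a, b, num_iters=2000):
  """Hypothetical: approximate projection onto {X >= 0, X 1 = a, X^T 1 = b}."""
  def body(_, state):
    x, p, q = state
    z = _project_marginals(x + p, a, b)  # step onto the affine set
    p = x + p - z                        # Dykstra correction for the affine set
    x_new = jnp.maximum(z + q, 0.0)      # step onto the nonnegative orthant
    q = z + q - x_new                    # Dykstra correction for the orthant
    return x_new, p, q

  zeros = jnp.zeros_like(y)
  x, _, _ = jax.lax.fori_loop(0, num_iters, body, (y, zeros, zeros))
  return x


def projection_birkhoff_dykstra(y, num_iters=2000):
  """Hypothetical: doubly stochastic matrices are the transport polytope
  with all-ones marginals."""
  ones = jnp.ones(y.shape[0], dtype=y.dtype)
  return projection_transport_dykstra(y, ones, ones, num_iters)
```

A quick sanity check is that a matrix already in the polytope is a fixed point: projection_birkhoff_dykstra(jnp.eye(3)) returns the identity. The fixed iteration count keeps the loop a simple jax.lax.fori_loop; a tolerance-based stop via jax.lax.while_loop would be the natural refinement.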

aymuos15, Jun 22 '25

I suggest creating a separate PR on top of the current head and then merging after https://github.com/google-deepmind/optax/pull/1351 is merged (or waiting until https://github.com/google-deepmind/optax/pull/1351 is merged and then creating a PR on top of the new head).

You're right that some of these projections depend on JAXopt-specific solvers that aren't (yet) present in Optax. I'd suggest holding off on those for now, until Optax decides what to do about those solvers (see https://github.com/google-deepmind/optax/issues/977). Feel free to chime in on that discussion as well.
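As background on why the box projection needs a root finder (presumably JAXopt's projection_box_section, since a plain box projection is just a clip): projecting y onto {lower <= x <= upper, w @ x = c} reduces, via the KKT conditions, to x = clip(y - tau * w, lower, upper) for a scalar tau that makes w @ x equal c. That scalar equation is nonincreasing in tau, so plain bisection solves it. The sketch below only illustrates this reduction; the function name and signature are assumptions, and it assumes w > 0 elementwise and w @ lower <= c <= w @ upper.

```python
import jax
import jax.numpy as jnp


def projection_box_section_bisect(y, w, c, lower, upper, num_iters=60):
  """Hypothetical: approximate projection of y onto
  {x : lower <= x <= upper, w @ x = c}, assuming w > 0 and feasibility."""
  def x_of(tau):
    # KKT form of the solution for a given multiplier tau.
    return jnp.clip(y - tau * w, lower, upper)

  # At tau_lo every coordinate sits at its upper bound; at tau_hi at its lower.
  tau_lo = jnp.min((y - upper) / w)
  tau_hi = jnp.max((y - lower) / w)

  def body(_, bracket):
    lo, hi = bracket
    mid = 0.5 * (lo + hi)
    too_big = w @ x_of(mid) > c  # w @ x still above c: root lies to the right
    return jnp.where(too_big, mid, lo), jnp.where(too_big, hi, mid)

  lo, hi = jax.lax.fori_loop(0, num_iters, body, (tau_lo, tau_hi))
  return x_of(0.5 * (lo + hi))
```

For example, with y = [0.8, -0.3, 1.5, 0.2], w all ones, c = 1 and bounds [0, 1] (a capped simplex), the root is tau = 0.65 and the result is [0.15, 0, 0.85, 0].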

carlosgmartin, Jun 22 '25

Perfect! Thanks a lot. I'll go through the solver issue and see if I can add anything to it.

I'll just wait for your PR to be merged, then. I think it's much easier/cleaner that way.

aymuos15, Jun 22 '25