optax
Add missing projections from jaxopt
Related: https://github.com/google-deepmind/optax/issues/977
The following projections are present in jaxopt but missing in optax:
- [ ] sparse_simplex
- [x] hyperplane (now available here)
- [x] halfspace (now available here)
- [ ] affine_set
- [ ] polyhedron
- [ ] box_section
- [ ] transport
- [ ] birkhoff
I can work on some of these.
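To give a feel for what one of these involves, here is a minimal JAX sketch of `sparse_simplex`. It assumes the approach of keeping the `max_nz` largest entries and projecting only those onto the simplex. The function names and the `(x, max_nz, value)` signature are hypothetical and chosen for illustration. This is not Optax's or jaxopt's API, and it is not a port of jaxopt's code.

```python
import jax
import jax.numpy as jnp


def projection_simplex(x, value=1.0):
    """Euclidean projection of a 1-D array onto {y >= 0, sum(y) = value} (sort-based)."""
    u = jnp.sort(x)[::-1]                        # entries in decreasing order
    cssv = jnp.cumsum(u) - value                 # cumulative sums minus the target mass
    ind = jnp.arange(1, x.shape[0] + 1)
    cond = u - cssv / ind > 0
    rho = jnp.max(jnp.where(cond, ind, 0))       # number of entries that stay positive
    theta = cssv[rho - 1] / rho                  # threshold subtracted from every entry
    return jnp.maximum(x - theta, 0.0)


def projection_sparse_simplex(x, max_nz, value=1.0):
    """Hypothetical sketch: simplex projection with at most `max_nz` nonzeros.

    Keeps the `max_nz` largest entries, projects them onto the simplex,
    and sets every other entry to zero.
    """
    top_vals, top_idx = jax.lax.top_k(x, max_nz)
    proj = projection_simplex(top_vals, value)
    return jnp.zeros_like(x).at[top_idx].set(proj)


out = projection_sparse_simplex(jnp.array([0.5, 2.0, -1.0, 1.5]), max_nz=2)
print(out)
```

Because `max_nz` fixes the output shape of `jax.lax.top_k`, it has to be a static Python int under `jax.jit` (for example via `static_argnums`).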
I am interested in contributing to this (and in learning along the way, as I have only recently started getting comfortable with JAX in general).
Could someone outline the best way to approach this?
Thanks!
@aymuos15 To see how Optax currently implements its projection functions, see the projections module.
To see how JAXopt implements the ones in the above list, you can click on one of them (e.g. projection_sparse_simplex) and then click on the button labeled [source], which will take you to the source code.
Note that I'm already adding a few at https://github.com/google-deepmind/optax/pull/1351, so you can work on some of the remaining ones, if you'd like.
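As an example of reading one of these off the math: for `affine_set`, i.e. the set {y : A y = b}, the Euclidean projection has a closed form when A has full row rank, namely `x - A.T @ solve(A @ A.T, A @ x - b)`. The sketch below assumes full row rank. The function name is hypothetical, and this is not jaxopt's implementation.

```python
import jax.numpy as jnp


def projection_affine_set_full_rank(x, A, b):
    """Hypothetical sketch: Euclidean projection of x onto {y : A @ y = b}.

    Assumes A has full row rank, so that A @ A.T is invertible.
    """
    residual = A @ x - b
    # Lagrange multipliers of the equality constraints: (A A^T) lam = A x - b.
    lam = jnp.linalg.solve(A @ A.T, residual)
    return x - A.T @ lam


A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
print(projection_affine_set_full_rank(jnp.array([3.0, -1.0]), A, b))
```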
Thank you so much, @carlosgmartin.
I will get on with this in the morning, my time.
I have started working on transport and birkhoff: https://github.com/aymuos15/optax/tree/more_projections
Two questions:
- Should I create a separate PR, or open a PR into your fork?
- How should I handle dependencies that exist in jaxopt but not in optax?
  - affine_set depends on EqualityConstrainedQP
  - polyhedron depends on OSQP
  - box_section depends on Bisection
Should I implement these solvers as well, or is there another way to proceed? Thanks.
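A rough sketch of one way `transport` and `birkhoff` could be approached without a jaxopt solver is to alternate, Dykstra-style, between two sets. The first is the affine set of matrices with prescribed row and column sums, which has a closed-form projection. The second is the nonnegative orthant. The Birkhoff polytope is the special case of square matrices whose marginals are all 1. This swaps in Dykstra's alternating projections and is not necessarily what jaxopt does. The function names and the fixed `num_iters` are assumptions, and the result is only approximate after finitely many iterations.

```python
import jax
import jax.numpy as jnp


def _project_marginals(x, a, b):
    """Projection onto the affine set {P : P @ 1 = a, P.T @ 1 = b} (no sign constraint).

    Assumes sum(a) == sum(b).
    """
    m, n = x.shape
    r = x.sum(axis=1) - a          # row-sum residuals, shape (m,)
    c = x.sum(axis=0) - b          # column-sum residuals, shape (n,)
    s = x.sum() - a.sum()          # total-mass residual
    return x - r[:, None] / n - c[None, :] / m + s / (m * n)


def projection_transport_dykstra(x, a, b, num_iters=500):
    """Hypothetical sketch: approximate projection onto {P >= 0, P @ 1 = a, P.T @ 1 = b}.

    Dykstra's algorithm; the affine step needs no correction term,
    so only the nonnegativity step carries one (q).
    """
    def body(_, state):
        x_k, q_k = state
        y = _project_marginals(x_k, a, b)
        x_next = jnp.maximum(y + q_k, 0.0)
        q_next = y + q_k - x_next
        return x_next, q_next

    x_out, _ = jax.lax.fori_loop(0, num_iters, body, (x, jnp.zeros_like(x)))
    return x_out


def projection_birkhoff_dykstra(x, num_iters=500):
    """Birkhoff polytope: transport polytope of a square matrix with unit marginals."""
    ones = jnp.ones(x.shape[0], dtype=x.dtype)
    return projection_transport_dykstra(x, ones, ones, num_iters)


print(projection_birkhoff_dykstra(jnp.array([[0.0, 1.0], [1.0, -3.0]])))
```

A fixed iteration count keeps the loop `jit`-friendly; a tolerance-based `jax.lax.while_loop` would be the natural alternative.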
I suggest creating a separate PR on top of the current head and then merging after https://github.com/google-deepmind/optax/pull/1351 is merged (or waiting until https://github.com/google-deepmind/optax/pull/1351 is merged and then creating a PR on top of the new head).
You're right that some of these projections are dependent on JAXopt-specific solvers that aren't (yet) present in Optax. I'd suggest holding off on those for now, until Optax decides what to do about those solvers (see https://github.com/google-deepmind/optax/issues/977). Feel free to also chime into that discussion.
Perfect! Thanks a lot. I'll go through the solver issue and see if I can add anything to it.
I will just wait for your PR to be merged, then. I think it's much easier and cleaner that way.