Proximal gradient and L-BFGS
Hi! Thank you for making Optimistix!
We're currently relying on JAXopt in NeMoS, but we're looking to transition to Optimistix (+ Optax). For this, we would need implementations of proximal gradient descent and L-BFGS. We saw that there are plans to implement L-BFGS and some discussion about constrained optimization. Are these currently in development by any chance? In case you're interested, we'd be happy to contribute. We're still getting familiar with the libraries and would appreciate guidance on how best to approach this. Some initial thoughts:
- Proximal gradient: I made a toy implementation for Lasso and Ridge that chains `optax.sgd` and a step imitating Optax's projections, which seems to work with `optimistix.minimise`. Do you think this makes sense, or should the proximal operator be handled inside Optimistix instead?
- L-BFGS: In my experiments, wrapping `optax.lbfgs` with `optimistix.OptaxMinimiser` runs without Optax's line search but fails when it's enabled.
  - Would it make sense and be feasible to turn `OptaxMinimiser` into a descent object, then combine it with Optimistix's searches?
  - Maybe modify the wrapper to support Optax's line search transformations?
  - Or drop Optax and write the required descent step in Optimistix?
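To make the proximal-gradient idea from the first bullet concrete, here is a minimal sketch in pure JAX (no Optax or Optimistix calls), assuming a Lasso objective `0.5/n * ||X w - y||^2 + lam * ||w||_1`. Each iteration takes a plain gradient step on the smooth least-squares term (what `optax.sgd` would do) and then applies the proximal operator of the L1 term, which is elementwise soft-thresholding. The function names `soft_threshold` and `lasso_prox_grad` are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp


def soft_threshold(x, threshold):
    # Proximal operator of threshold * ||x||_1: shrink each entry towards zero.
    return jnp.sign(x) * jnp.maximum(jnp.abs(x) - threshold, 0.0)


def lasso_prox_grad(X, y, lam, step_size, num_steps=500):
    """Minimise 0.5/n * ||X w - y||^2 + lam * ||w||_1 by proximal gradient (ISTA)."""
    n = X.shape[0]

    def smooth_loss(w):
        # Differentiable part only; the L1 penalty is handled by the prox step.
        residual = X @ w - y
        return 0.5 * jnp.sum(residual ** 2) / n

    grad_fn = jax.grad(smooth_loss)

    def step(w, _):
        w = w - step_size * grad_fn(w)          # plain gradient step on the smooth term
        w = soft_threshold(w, step_size * lam)  # proximal step on the L1 term
        return w, None

    w0 = jnp.zeros(X.shape[1])
    w_final, _ = jax.lax.scan(step, w0, None, length=num_steps)
    return w_final


X = jnp.eye(3)
y = jnp.array([3.0, 0.1, -2.0])
w = lasso_prox_grad(X, y, lam=0.5, step_size=1.0)
```

With `X` equal to the identity, the minimiser is `soft_threshold(y, n * lam)`, so the small middle coefficient is driven exactly to zero. For Ridge, the prox step would instead be that of `(lam / 2) * ||w||^2`, namely `w / (1 + step_size * lam)`.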
@BalzaniEdoardo, @billbrod, @sjvenditto, @gviejo
Hi @bagibence,
yes, constrained optimisation is currently in development! We have a few new descents and searches implemented. L-BFGS is also on the list, but I haven't actively worked on it while implementing the constrained searches and descents.
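For reference on the last option in the list above (writing the L-BFGS descent step directly), the core of L-BFGS is the standard two-loop recursion, which turns the current gradient and the stored curvature pairs into a search direction that a separate line search then scales. This is a minimal, non-jitted sketch in JAX; `lbfgs_direction` is a hypothetical name and does not reflect how Optimistix or Optax structure their code.

```python
import jax.numpy as jnp


def lbfgs_direction(grad, s_history, y_history):
    """Return the L-BFGS descent direction -H @ grad via the two-loop recursion.

    s_history[i] = x_{i+1} - x_i and y_history[i] = g_{i+1} - g_i, oldest first.
    Python loops keep the sketch readable; a jit-friendly version would use
    fixed-size buffers and jax.lax.fori_loop instead.
    """
    q = grad
    rhos = [1.0 / jnp.dot(y, s) for s, y in zip(s_history, y_history)]
    alphas = []
    # First loop: from the newest pair to the oldest.
    for s, y, rho in reversed(list(zip(s_history, y_history, rhos))):
        alpha = rho * jnp.dot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial inverse-Hessian approximation H_0 = gamma * I (usual scaling).
    if s_history:
        s, y = s_history[-1], y_history[-1]
        gamma = jnp.dot(s, y) / jnp.dot(y, y)
    else:
        gamma = 1.0
    r = gamma * q
    # Second loop: from the oldest pair to the newest; alphas were stored newest first.
    for (s, y, rho), alpha in zip(zip(s_history, y_history, rhos), reversed(alphas)):
        beta = rho * jnp.dot(y, r)
        r = r + s * (alpha - beta)
    return -r


# Quadratic with Hessian diag(2, 5); curvature pairs along each axis.
s_hist = [jnp.array([1.0, 0.0]), jnp.array([0.0, 1.0])]
y_hist = [jnp.array([2.0, 0.0]), jnp.array([0.0, 5.0])]
direction = lbfgs_direction(jnp.array([1.0, 1.0]), s_hist, y_hist)
```

Here the pairs fully determine the diagonal Hessian, so the direction equals the Newton step `-A^{-1} g = [-0.5, -0.2]`. With an empty history the recursion falls back to steepest descent.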
How about we discuss these plans on a video call?
Thank you for the quick response, that sounds fantastic!
Absolutely, I will send you an email about the video call.
@bagibence and @BalzaniEdoardo, I've opened a draft PR for bounded and constrained optimisation, which includes a projected gradient descent very similar to what Optax and JAXopt do: take a step, then project. I have included a box projection as an example boundary map; following this template, other sets such as L1 and L2 balls and spheres should be easy to implement.
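The "take a step, then project" pattern with a box projection can be sketched in a few lines of standalone JAX. This is an illustration of the technique only, not the code in the draft PR; `project_box` and `projected_gradient_descent` are hypothetical names, and the fixed step size is an assumption (the PR may pair the descent with a line search instead).

```python
import jax
import jax.numpy as jnp


def project_box(x, lower, upper):
    # Euclidean projection onto the box {x : lower <= x <= upper}.
    return jnp.clip(x, lower, upper)


def projected_gradient_descent(fn, x0, lower, upper, step_size, num_steps=200):
    grad_fn = jax.grad(fn)

    def step(x, _):
        x = x - step_size * grad_fn(x)    # unconstrained gradient step...
        x = project_box(x, lower, upper)  # ...then map back onto the feasible set
        return x, None

    x_start = project_box(x0, lower, upper)
    x_final, _ = jax.lax.scan(step, x_start, None, length=num_steps)
    return x_final


target = jnp.array([2.0, -3.0, 0.5])
objective = lambda x: jnp.sum((x - target) ** 2)
x_opt = projected_gradient_descent(objective, jnp.zeros(3), -1.0, 1.0, step_size=0.25)
```

For this separable objective the constrained minimiser is simply the target clipped to the box, `[1.0, -1.0, 0.5]`. Swapping in another boundary map (for example a projection onto an L2 ball) only changes `project_box`.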
Since you mention proximal gradient descent: that is actually a little different. IIUC, it entails forming a proximal operator from the objective function and the projection operator.
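For reference, the textbook relationship between the two, stated for an objective split as $f + g$ with $f$ smooth and $g$ convex but possibly non-smooth, with step size $\eta$:

$$
\operatorname{prox}_{\eta g}(v) = \arg\min_x \left( g(x) + \frac{1}{2\eta}\lVert x - v \rVert_2^2 \right),
\qquad
x_{k+1} = \operatorname{prox}_{\eta g}\big(x_k - \eta \nabla f(x_k)\big).
$$

Choosing $g = \iota_C$, the indicator of a closed convex set $C$ ($0$ on $C$, $+\infty$ elsewhere), gives $\operatorname{prox}_{\eta \iota_C}(v) = \Pi_C(v)$, so projected gradient descent is the special case where the prox is a projection. For Lasso, $g = \lambda \lVert x \rVert_1$ and the prox is soft-thresholding, $\big[\operatorname{prox}_{\eta g}(v)\big]_i = \operatorname{sign}(v_i)\max(\lvert v_i \rvert - \eta\lambda, 0)$.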
Looking forward to getting to know you both on Monday :)