OptaxMinimiser does not work with Optax-Linesearch
Hey,
I was just trying to use optimistix.OptaxMinimiser() with the L-BFGS optimizer from optax. This, however, does not work, because the included linesearch requires additional inputs to the update function. The code below results in the error:
TypeError: scale_by_backtracking_linesearch.<locals>.update_fn() missing 3 required keyword-only arguments: 'value', 'grad', and 'value_fn'
import optimistix
import optax
import jax.numpy as jnp
def test_func(x, args):
    return jnp.sum(jnp.abs(x) ** 2)

# L-BFGS
solver = optimistix.OptaxMinimiser(optax.lbfgs(learning_rate=1), rtol=1e-6, atol=1e-6)

# Adam + linesearch
solver = optimistix.OptaxMinimiser(
    optax.chain(optax.adam(learning_rate=1), optax.scale_by_backtracking_linesearch(25)),
    rtol=1e-6, atol=1e-6,
)

# Either solver triggers the error above:
solution = optimistix.minimise(test_func, solver, y0=jnp.ones(10))
I think fixing this should be relatively easy, because the other optax optimizers don't care if additional kwargs are provided. E.g. below, the optax.adam() optimizer runs regardless of whether the extra inputs are provided or not.
import jax

start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)
params = jnp.ones(10)
opt_state = optimizer.init(params)

for _ in range(1000):
    value, grads = jax.value_and_grad(test_func)(params, args=None)
    updates, opt_state = optimizer.update(grads, opt_state)  # standard usage
    # With the linesearch-style inputs:
    # updates, opt_state = optimizer.update(
    #     grads, opt_state, params,
    #     value=value, grad=grads, value_fn=lambda p: test_func(p, None),
    # )
    params = optax.apply_updates(params, updates)
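Related, possibly: if I read the optax docs right, there is also optax.with_extra_args_support(), which wraps a plain transform so that its update accepts and simply ignores extra keyword arguments. A minimal sketch of what I mean (untested; the exact behaviour is my reading of the docs):
import jax
import jax.numpy as jnp
import optax

# Wrap adam so that update() tolerates the linesearch-style keyword arguments.
optimizer = optax.with_extra_args_support(optax.adam(1e-1))
params = jnp.ones(10)
opt_state = optimizer.init(params)
value, grads = jax.value_and_grad(lambda x: jnp.sum(jnp.abs(x) ** 2))(params)
# The extra kwargs are accepted and ignored by the wrapped adam.
updates, opt_state = optimizer.update(
    grads, opt_state, params,
    value=value, grad=grads, value_fn=lambda x: jnp.sum(jnp.abs(x) ** 2),
)
params = optax.apply_updates(params, updates)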
Anyway, since optimistix has its own linesearch options, this is not crucial. But maybe there are also other optax features which cause similar issues with OptaxMinimiser()?
Hi Matilda,
thanks for the issue!
For the points you are raising: do I understand correctly that you'd like optim.update to accept more keyword arguments in this line: https://github.com/patrick-kidger/optimistix/blob/482f9a26d8e3b5f2cc3927ffc8fc9d9b554c24ad/optimistix/_solver/optax.py#L100
I think this sounds reasonable, and looking at https://github.com/google-deepmind/optax/commit/913851292926ca7359175970fcaf3ba775b1fef5 it also seems like this would be behaviour that is fully supported on the optax side.
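To sketch the kind of change I have in mind (everything here is hypothetical - names like value_fn and the surrounding variables are placeholders rather than the actual optimistix internals - the extra keyword arguments would just need to be threaded through to that one call):
# Hypothetical sketch of the tweak around optax.py#L100; `value`, `grads` and a
# scalar objective callable `value_fn` are assumed to be available in the solver.
updates, new_opt_state = self.optim.update(
    grads,
    state.opt_state,
    y,
    value=value,
    grad=grads,
    value_fn=value_fn,
)
# Optimisers that don't take extra arguments would presumably need to be wrapped
# with optax.with_extra_args_support(...) so that the extra kwargs are ignored.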
Anyway, since optimistix has its own linesearch options, this is not crucial.
optx.OptaxMinimiser does not accept a search attribute, though. So if you'd really like to use their line searches, then this tweak would be required. You might also be interested in hearing that we have our own L-BFGS in the works :)
But maybe there are also other optax features which cause similar issues with OptaxMinimiser()?
We're thinking about projected and proximal gradient descent flavours, and you raising this issue actually gives me a little food for thought - it might be reasonable to handle everything through OptaxMinimiser instead, by supporting all the keywords that might be required for the projections optax implements. I had written up a wrapper instead, which might still be reasonable for a more complex projection onto some manifold defined by a constraint function, which optax currently does not support. (A rough sketch of the projected-gradient idea is below.)
@patrick-kidger WDYT?
(optax projections: here)
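Rough sketch of the projected-gradient flavour I mean, just for concreteness (illustrative only; assumes optax.projections.projection_non_negative is available in your optax version):
import jax
import jax.numpy as jnp
import optax

def loss(params):
    return jnp.sum((params - 2.0) ** 2)

opt = optax.sgd(learning_rate=0.1)
params = jnp.full(3, -5.0)
opt_state = opt.init(params)

for _ in range(100):
    grads = jax.grad(loss)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # Project back onto the feasible set after each step
    # (here: the non-negative orthant).
    params = optax.projections.projection_non_negative(params)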
Judging from https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html#linesearches-in-practice it should be pretty simple to add in the three extra arguments it expects.
I'd be happy to take a PR on this! (Including a test to be sure that we support this properly.)
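For reference, the usage pattern on that page looks roughly like this, so the PR essentially just needs to thread value, grad and value_fn through to optim.update:
import jax.numpy as jnp
import optax

def fun(x):
    return jnp.sum(jnp.abs(x) ** 2)

opt = optax.lbfgs()
params = jnp.ones(10)
opt_state = opt.init(params)
value_and_grad = optax.value_and_grad_from_state(fun)

for _ in range(50):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = opt.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=fun
    )
    params = optax.apply_updates(params, updates)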
@bagibence @BalzaniEdoardo
I will send a PR for this in the afternoon!
Hey guys,
The "L-BFGS" drop got me hooked!
Is there a timeline, or working code that I can use in the meantime?
The optax one's line-search is based on lax.while_loop and is thus not reverse-mode differentiable.
The "L-BFGS" drop got me hooked! Is there a time-line or working code that I can use in the mean time?
We're working on something related in https://github.com/patrick-kidger/optimistix/pull/125. The AbstractQuasiNewtonUpdate could be subclassed to create different Hessian approximations, including L-BFGS. I have some prototype-y stuff lying around, but am currently focusing on constrained optimisation.
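In the meantime, the core of any L-BFGS update is the standard two-loop recursion, which fits in a few lines of JAX. A plain textbook sketch over flat parameter vectors - explicitly not the optimistix / #125 implementation:
import jax.numpy as jnp

def lbfgs_direction(grad, s_history, y_history):
    # Standard L-BFGS two-loop recursion: approximates H^{-1} @ grad using the
    # last m displacement/gradient-difference pairs s_i = x_{i+1} - x_i and
    # y_i = g_{i+1} - g_i. Histories are Python lists of 1D arrays, oldest first.
    rhos = [1.0 / jnp.vdot(y, s) for s, y in zip(s_history, y_history)]
    q = grad
    alphas = []
    for s, y, rho in reversed(list(zip(s_history, y_history, rhos))):
        alpha = rho * jnp.vdot(s, q)
        q = q - alpha * y
        alphas.append(alpha)
    # Initial inverse-Hessian scaling gamma = (s_k . y_k) / (y_k . y_k).
    gamma = jnp.vdot(s_history[-1], y_history[-1]) / jnp.vdot(y_history[-1], y_history[-1])
    r = gamma * q
    for (s, y, rho), alpha in zip(zip(s_history, y_history, rhos), reversed(alphas)):
        beta = rho * jnp.vdot(y, r)
        r = r + (alpha - beta) * s
    return -r  # search direction; combine with a line search for the step size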
The optax one's line-search is based on lax.while_loop and is thus not reverse-mode differentiable.
That is very good to know - I was not aware. It seems we do have an unmet need there, then!
Closing as fixed in #122.