
OptaxMinimiser does not work with Optax-Linesearch

Open matillda123 opened this issue 9 months ago • 7 comments

Hey,

I was just trying to use optimistix.OptaxMinimiser() with the L-BFGS optimizer from optax. This, however, does not work, because the included line search requires additional inputs to the update function. The code below results in the error:

TypeError: scale_by_backtracking_linesearch.<locals>.update_fn() missing 3 required keyword-only arguments: 'value', 'grad', and 'value_fn'

import jax.numpy as jnp
import optax
import optimistix


def test_func(x, args):
    return jnp.sum(jnp.abs(x) ** 2)

# L-BFGS
solver = optimistix.OptaxMinimiser(optax.lbfgs(learning_rate=1), rtol=1e-6, atol=1e-6)

# Adam + line search
solver = optimistix.OptaxMinimiser(
    optax.chain(optax.adam(learning_rate=1), optax.scale_by_backtracking_linesearch(25)),
    rtol=1e-6, atol=1e-6)

solution = optimistix.minimise(test_func, solver, y0=jnp.ones(10))

I think fixing this should be relatively easy because the other optax optimizers don't care if additional kwargs are provided. E.g. below, the optax.adam() optimizer runs regardless of whether the extra inputs are provided or not.

import jax

start_learning_rate = 1e-1
optimizer = optax.adam(start_learning_rate)
params = jnp.ones(10)
opt_state = optimizer.init(params)


for _ in range(1000):
  value, grads = jax.value_and_grad(test_func)(params, args=None)

  updates, opt_state = optimizer.update(grads, opt_state)  # standard usage
  # updates, opt_state = optimizer.update(grads, opt_state, params, value=value, grad=grads, value_fn=test_func)  # with linesearch inputs

  params = optax.apply_updates(params, updates)

Anyway, since optimistix has its own line search options, this is nothing crucial. But maybe there are also other optax features which cause similar issues with OptaxMinimiser()?

matillda123 avatar Mar 09 '25 17:03 matillda123

Hi Matilda,

thanks for the issue!

For the points you are raising: do I understand correctly that you'd like optim.update to accept more keyword arguments in this line: https://github.com/patrick-kidger/optimistix/blob/482f9a26d8e3b5f2cc3927ffc8fc9d9b554c24ad/optimistix/_solver/optax.py#L100

I think this sounds reasonable, and looking at https://github.com/google-deepmind/optax/commit/913851292926ca7359175970fcaf3ba775b1fef5 it also seems like this would be behaviour that is fully supported on the optax side.
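For reference, the linked line currently reads `updates, new_opt_state = self.optim.update(grads, state.opt_state, y)`. Below is a minimal sketch of what forwarding the extra keyword arguments could look like - this is just an illustration of the optax API, not the actual optimistix implementation; wrapping with optax.with_extra_args_support keeps transforms that ignore the extra arguments compatible:

import jax
import jax.numpy as jnp
import optax


def fn(y, args):
    return jnp.sum(jnp.abs(y) ** 2)


# Wrapping gives a GradientTransformationExtraArgs, so transforms that
# don't use the extra keyword arguments simply ignore them.
optim = optax.with_extra_args_support(optax.lbfgs(learning_rate=1.0))

y = jnp.ones(10)
opt_state = optim.init(y)
value, grads = jax.value_and_grad(fn)(y, None)

# The keyword arguments required by the line search inside the optimiser.
updates, new_opt_state = optim.update(
    grads,
    opt_state,
    y,
    value=value,
    grad=grads,
    value_fn=lambda y_: fn(y_, None),
)
y = optax.apply_updates(y, updates)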

Anyway, since optimistix has its own line search options, this is nothing crucial.

optx.OptaxMinimiser does not accept a search attribute, though. So if you'd really like to use their line searches, then this tweak would be required. You might also be interested in hearing that we have our own L-BFGS in the works :)
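In the meantime, the existing full-memory optimistix.BFGS, which uses our own line search, already works for your example - a quick sketch:

import jax.numpy as jnp
import optimistix


def test_func(x, args):
    return jnp.sum(jnp.abs(x) ** 2)


# Full-memory BFGS with optimistix's built-in line search.
solver = optimistix.BFGS(rtol=1e-6, atol=1e-6)
solution = optimistix.minimise(test_func, solver, y0=jnp.ones(10))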

But maybe there are also other optax features which cause similar issues with OptaxMinimiser()?

We're thinking about projected and proximal gradient descent flavours, and you raising this issue actually gives me a little food for thought - it might be reasonable to handle everything through OptaxMinimiser instead, by supporting all the keywords that might be required for the projection they implement. I had written up a wrapper instead - which might still be reasonable for a more complex projection onto some manifold defined by a constraint function, which optax currently does not support.
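To sketch the projected gradient descent idea (a rough illustration using optax.projections, not the planned optimistix API - the loss and the non-negativity constraint are just placeholders):

import jax
import jax.numpy as jnp
import optax


def loss(x):
    return jnp.sum((x - 2.0) ** 2)


opt = optax.sgd(learning_rate=0.1)
params = jnp.zeros(3)
opt_state = opt.init(params)

for _ in range(100):
    grads = jax.grad(loss)(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Project back onto the feasible set (here the non-negative orthant).
    params = optax.projections.projection_non_negative(params)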

@patrick-kidger WDYT?

johannahaffner avatar Mar 09 '25 23:03 johannahaffner

(optax projections: here)

johannahaffner avatar Mar 09 '25 23:03 johannahaffner

Judging from https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html#linesearches-in-practice it should be pretty simple to add in the three extra arguments it expects.
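For reference, the pattern in that example looks roughly like this (a sketch based on the linked docs; optax.value_and_grad_from_state reuses values already cached by the line search):

import jax.numpy as jnp
import optax


def f(x):
    return jnp.sum(x**2)


opt = optax.lbfgs()
params = jnp.ones(10)
state = opt.init(params)
value_and_grad = optax.value_and_grad_from_state(f)

for _ in range(20):
    value, grad = value_and_grad(params, state=state)
    # The three extra arguments the line search expects.
    updates, state = opt.update(
        grad, state, params, value=value, grad=grad, value_fn=f
    )
    params = optax.apply_updates(params, updates)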

I'd be happy to take a PR on this! (Including a test to be sure that we support this properly.)

patrick-kidger avatar Mar 10 '25 13:03 patrick-kidger

@bagibence @BalzaniEdoardo

johannahaffner avatar Mar 10 '25 18:03 johannahaffner

I will send a PR for this in the afternoon!

bagibence avatar Mar 12 '25 11:03 bagibence

Hey guys,

The "L-BFGS" drop got me hooked! Is there a time-line or working code that I can use in the mean time? The optax one's line-search is based on lax.while_loop and is thus not reverse-mode differentiable.

SNMS95 avatar Mar 19 '25 10:03 SNMS95

The "L-BFGS" drop got me hooked! Is there a time-line or working code that I can use in the mean time?

We're working on something related in https://github.com/patrick-kidger/optimistix/pull/125. The AbstractQuasiNewtonUpdate could be subclassed to create different Hessian approximations, including L-BFGS. I have some prototype-y stuff lying around, but am currently focusing on constrained optimisation.

The optax one's line-search is based on lax.while_loop and is thus not reverse-mode differentiable.

That is very good to know - I was not aware. It seems we do have an unmet need there, then!

johannahaffner avatar Mar 19 '25 11:03 johannahaffner

Closing as fixed in #122.

patrick-kidger avatar Oct 17 '25 20:10 patrick-kidger