
Support for loss function with auxiliary data in linesearch

Open ro0mquy opened this issue 1 year ago • 4 comments

I have a loss function that returns (loss_value, extra_data). JAX itself supports this kind of construct with jax.value_and_grad(loss_fn, has_aux=True). The differentiated function returns ((loss_value, extra_data), grad).
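For illustration, here is a minimal sketch of that pattern. The quadratic loss and the contents of extra_data are made up for the example and are not taken from my actual code:

```python
import jax
import jax.numpy as jnp

# Hypothetical loss: squared distance to a fixed target, plus auxiliary data.
def loss_fn(params):
    residual = params - jnp.array([1.0, 2.0])
    loss_value = jnp.sum(residual ** 2)
    extra_data = {"max_abs_residual": jnp.max(jnp.abs(residual))}
    return loss_value, extra_data

params = jnp.zeros(2)

# has_aux=True tells JAX that only the first output is differentiated;
# the second output is passed through unchanged.
(loss_value, extra_data), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
```

The gradient is taken with respect to loss_value only; extra_data can be any pytree.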

In optax, when using the linesearch algorithms (for example as part of L-BFGS), I can use optax.value_and_grad_from_state(loss_fn), which reuses function evaluations already stored in the optimizer state by the linesearch instead of recomputing them. Unfortunately, the linesearch algorithms and optax.value_and_grad_from_state don't support auxiliary data.

I added support for this to the optax code, and it works for my use case. Are you interested in merging this upstream? I don't have time for proper testing, documentation, etc., though, so I would appreciate some assistance.

ro0mquy, Sep 10 '24 08:09

Hello @ro0mquy,

I'd be happy to see how you handled it. I was not sure what the best way to add this would be while keeping the API light. So if you have an example, I'd be happy to look at a PR.

Thanks!

vroulet, Sep 10 '24 15:09

Cool, I'll prepare a PR once I'm back from vacation in 1-2 weeks.

ro0mquy, Sep 11 '24 10:09

Hey, I am picking this up again. How can I format my code to match the style of the code base?

ro0mquy, Jan 16 '25 10:01

I submitted a draft PR: https://github.com/google-deepmind/optax/pull/1177

ro0mquy, Jan 16 '25 11:01