Support for loss function with auxiliary data in linesearch
I have a loss function that returns (loss_value, extra_data). Native JAX supports this kind of construct with jax.value_and_grad(loss_fn, has_aux=True) (doc). The differentiated function then returns ((loss_value, extra_data), grad).
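To illustrate, a minimal sketch of the `has_aux=True` pattern (the quadratic `loss_fn` and its aux dict are made-up examples, not from my actual code):

```python
import jax
import jax.numpy as jnp

def loss_fn(params):
    residual = params - 3.0
    loss = jnp.sum(residual ** 2)
    aux = {"residual": residual}  # arbitrary extra data returned alongside the loss
    return loss, aux

params = jnp.array([1.0, 2.0])
# has_aux=True tells JAX the second output is auxiliary data, not part of the loss
(loss, aux), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
```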
In optax, when using the linesearch algorithms (for example as part of L-BFGS), I can use optax.value_and_grad_from_state(loss_fn) (doc), which reuses the function evaluations already done inside the linesearch and stored in the optimizer state. Unfortunately, the linesearch algorithms and optax.value_and_grad_from_state don't support auxiliary data.
I added support for this in the optax code, and it works for my use case. Are you interested in merging this upstream? I don't have time for proper testing, documentation, etc., though, so I would appreciate some assistance.
Hello @ro0mquy,
I'd be happy to see how you handled it. I was not sure what would be the best solution to add this while keeping the API light. So if you have some example, I'd be happy to look at a PR.
Thanks!
Cool, I'll prepare a PR once I'm back from vacation in 1-2 weeks.
Hey, I'm picking this up again. How can I format my code to match the style of the code base?
I submitted a draft PR https://github.com/google-deepmind/optax/pull/1177