
Using tfp.substrates.jax.optimizer.lbfgs_minimize with Equinox

Open raj-brown opened this issue 1 year ago • 6 comments

Dear all, I want to use the L-BFGS optimizer and was wondering if there is any example using an Equinox neural network model with the tfp.substrates.jax.optimizer.lbfgs_minimize optimizer.

Thanks!

raj-brown · Dec 05 '24 23:12

I am not aware of any examples using TFP + Equinox, but it should be possible, since Equinox generally operates at the JAX level rather than as a wrapper. There are lots of examples using optax + Equinox, and optax also has an LBFGS example (https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html), so that might be a good starting place. Other implementations of potential use are https://jaxopt.github.io/stable/unconstrained.html and https://docs.kidger.site/optimistix/api/minimise/#optimistix.BFGS.
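
For concreteness, here's a minimal sketch of what optax's LBFGS + Equinox could look like. The data, model size, and iteration count are made up, and this assumes a recent optax that provides `optax.lbfgs` and `optax.value_and_grad_from_state`:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import optax

# Toy regression data, purely for illustration.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100, 2))
y = jnp.sum(x, axis=1, keepdims=True)

model = eqx.nn.MLP(in_size=2, out_size=1, width_size=8, depth=2, key=key)

# Split the model into trainable arrays and static (non-array) parts,
# since optax expects a pytree of arrays.
params, static = eqx.partition(model, eqx.is_array)

def loss_fn(params):
    model = eqx.combine(params, static)
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

opt = optax.lbfgs()
opt_state = opt.init(params)
# Reuses the value/gradient cached in the optimizer state between steps.
value_and_grad = optax.value_and_grad_from_state(loss_fn)

@jax.jit
def step(params, opt_state):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = opt.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=loss_fn
    )
    params = optax.apply_updates(params, updates)
    return params, opt_state, value

for _ in range(50):
    params, opt_state, value = step(params, opt_state)
```

The `eqx.partition`/`eqx.combine` step is just the usual trick for handing optax a pytree of arrays while keeping the non-array parts of the model out of the optimizer.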

lockwo · Dec 06 '24 01:12

Thank you @lockwo and @patrick-kidger, it worked for me. I had another question about the example at https://docs.kidger.site/equinox/all-of-equinox/: what if I want to maximize the loss function with respect to extra_bias only, and minimize it with respect to the rest of the parameters? How do I selectively choose the parameters for the min/max operation? I would really appreciate your suggestions @patrick-kidger @lockwo. Thanks!

raj-brown · Dec 09 '24 04:12

There are a couple of ways to do what you're describing (I assume you don't much care about that specific toy problem, but want to apply it to your own, potentially much more complicated use case):

1. Modify the loss function (the function being differentiated) so that the gradient is naturally what you want. For example, if the loss has two terms, you could flip the sign of the term that depends on those parameters.
2. Make the parameters a custom layer and write a custom gradient rule that negates their gradient (I'm not sure where this level of work would be needed, but it's possible).
3. What I would probably do in most situations: just compute the gradient, then add an extra step that multiplies the relevant components of the gradient by -1 before applying it, as in the sketch below.
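
For option 3, a minimal sketch, loosely following the model with an `extra_bias` field from the all-of-equinox page (the data and learning rate here are placeholders):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# A model with an `extra_bias` field, in the spirit of the
# all-of-equinox example.
class Model(eqx.Module):
    layers: list
    extra_bias: jax.Array

    def __init__(self, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.layers = [eqx.nn.Linear(2, 8, key=k1),
                       eqx.nn.Linear(8, 8, key=k2),
                       eqx.nn.Linear(8, 2, key=k3)]
        self.extra_bias = jnp.ones(2)

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = jax.nn.relu(layer(x))
        return self.layers[-1](x) + self.extra_bias

@eqx.filter_grad
def loss_fn(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred - y) ** 2)

key = jax.random.PRNGKey(0)
model = Model(key)
x = jax.random.normal(key, (32, 2))
y = jax.random.normal(key, (32, 2))

grads = loss_fn(model, x, y)

# Negate only the extra_bias gradient, so a descent step on the
# modified gradients *ascends* the loss with respect to extra_bias.
grads = eqx.tree_at(lambda g: g.extra_bias, grads, -grads.extra_bias)

# Apply the modified gradients as usual, e.g. plain gradient descent.
learning_rate = 0.1
updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
model = eqx.apply_updates(model, updates)
```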

lockwo · Dec 09 '24 05:12

Thanks @lockwo. I will take choice 3. I just had one doubt: how do I pick out the gradient corresponding to specific parameters in the pytree? For the same example, how would I know which gradients correspond to the extra_bias leaf? Thank you very much.

raj-brown · Dec 09 '24 05:12

The gradients form a pytree with the same structure as the parameters, so wherever extra_bias sits in the parameter pytree, its gradient sits in the same place in the gradient pytree.
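
Concretely, with the sketch above, something like:

```python
grads = loss_fn(model, x, y)  # same pytree structure as `model`
grads.extra_bias              # the gradient living where model.extra_bias lives
```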

lockwo · Dec 09 '24 05:12

Thank you @lockwo

raj-brown · Dec 09 '24 05:12