using tfp.substrates.jax.optimizer.lbfgs_minimize with Equinox
Dear All-
I want to use the L-BFGS optimizer and was wondering if there is any example of using an Equinox neural network model with the tfp.substrates.jax.optimizer.lbfgs_minimize optimizer.
Thanks!
I am not aware of any examples using TFP + Equinox, but it should be possible, given that Equinox generally operates at the JAX level rather than as a wrapper. There are lots of examples using Optax + Equinox, and Optax also has an LBFGS example (https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html), so that might be a good starting place. Other implementations of potential use are https://jaxopt.github.io/stable/unconstrained.html and https://docs.kidger.site/optimistix/api/minimise/#optimistix.BFGS.
Thank you @lockwo and @patrick-kidger, it worked for me. I have another question.
For the example at https://docs.kidger.site/equinox/all-of-equinox/: what if I want to maximize the loss function with respect to extra_bias only, and minimize it with respect to the remaining parameters? How do I selectively choose which parameters are minimized versus maximized? I would really appreciate your suggestions @patrick-kidger @lockwo. Thanks!
To do what you're describing, there are a couple of ways (since I assume you don't much care about that specific problem, but want to apply it to your own, potentially much more complicated use case):

1. You could modify the loss function/the function being differentiated so that the gradient is naturally what you want (e.g. if the loss function had two terms, you could change the sign on the term that depends on those parameters).
2. You could make the parameters a custom layer and write a custom gradient rule that just negates the gradient (not sure where this level of work would be needed, but it's possible).
3. (What I would probably do in most situations:) just compute the gradient, then add an extra step that multiplies that component of the gradient by -1 before applying the update.
Thanks @lockwo. I will go with option 3. I just have one doubt: how do I pick out the gradient corresponding to a specific parameter in the pytree? For the same example, how would I know which gradients correspond to the extra_bias pytree leaf? Thank you very much.
The gradients are a pytree with the same structure as the parameters, so wherever the bias sits in the parameter pytree, the corresponding gradient sits in the same place in the gradient pytree.
thank you @lockwo