jaxopt
Problem differentiating through `solver.run` in `OptaxSolver`
I've been trying to use `OptaxSolver` to perform a simple function minimization, since I want to differentiate through its solution (the fixed point of the solver), but I ran into an issue I'm not familiar with.
Here's an MWE for the error message:
```python
import jax
import jax.scipy as jsp
from jaxopt import OptaxSolver
import optax

def pipeline(param_for_grad, data):
    def to_minimize(latent):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(3e-4), implicit_diff=True)
    initial, _ = solver.init(init_params = 5.)
    result, _ = solver.run(init_params = initial)
    return result

jax.value_and_grad(pipeline)(2., data=6.)
```
which yields this error:
CustomVJPException: Detected differentiation of a custom_vjp function with respect to a closed-over value. That isn't supported because the custom VJP rule only specifies how to differentiate the custom_vjp function with respect to explicit input parameters. Try passing the closed-over value into the custom_vjp function as an argument, and adapting the custom_vjp fwd and bwd rules.
My versions are:
jax==0.2.20
jaxlib==0.1.71
jaxopt==0.0.1
optax==0.0.9
Am I doing something very silly? I'm also wondering whether this example is within the scope of the solver API. I noticed that this doesn't occur with `solver.update`, only with `solver.run`.
Thanks :)
Hi,
I think only a couple of small changes would be needed.
To use implicit differentiation with `solver.run`, you should (1) explicitly expose, in the signature of `fun`, the args with respect to which you'd like to differentiate the solver's solution, and (2) avoid using keyword arguments in the call to `solver.run`.
In your MWE:
```python
def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)
    initial, _ = solver.init(init_params = 5.)
    result, _ = solver.run(initial, param_for_grad)
    return result

jax.value_and_grad(pipeline)(2., data=6.)
```
P.S. I also made a small change to the learning rate so that Adam converges in this example with the default maximum number of steps.
Thanks @phinate for the question and @fllinares for the answer!
Indeed, as Felipe explained, your `param_for_grad` was captured from the enclosing scope (this is what "closed-over value" means), but it wasn't an explicit argument of `run`.
By the way, since `run` calls `init` for you, the line `initial, _ = solver.init(init_params = 5.)` is not needed. You can just set `initial = 5.` and then call `run(initial, param_for_grad)`.
We are working on documentation; hopefully these things will become clearer soon.
Thanks both @mblondel & @fllinares for your replies, and the helpful information!
I'm struggling a little with this because making `param_for_grad` an explicit argument of `to_minimize` is a bit cumbersome for my use case; the way I'm actually building this objective function looks more like:
```python
def setup_objective(param_for_grad, **kwargs):
    to_minimize = complicated_function(param_for_grad, **kwargs)
    return to_minimize

def pipeline(param_for_grad, **kwargs):
    obj = setup_objective(param_for_grad, **kwargs)
    solver = OptaxSolver(fun=obj, opt=optax.adam(5e-2), implicit_diff=True)
    ...  # etc.
    return result
```
Parametrizing the objective directly with `param_for_grad` would mean rebuilding it via `complicated_function` every time it is called in the minimization loop, even though, strictly speaking, it doesn't change with respect to `param_for_grad` during the minimization.
Am I missing a nicer way to set up this problem? Or, for implicit diff, do I really need to be explicit in the way you described, even though `param_for_grad` is only used to construct the objective and not to evaluate it?
Thanks again for the quick response earlier, and sorry if this is somehow unclear!
We decided to use explicit variables because with closed-over variables there is no way to tell which ones need to be differentiated and which don't. This is problematic if you have several big variables in your scope, such as data matrices.
By the way, you can also take a look at `jax.lax.custom_root`, which supports closed-over variables. CC @shoyer
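As an illustration of the `jax.lax.custom_root` route, here is a minimal sketch for the scalar toy problem from the first post, using only core JAX. The root function closes over `param_for_grad` rather than taking it as an argument; `custom_root` defines derivatives with respect to such closed-over values via the implicit function theorem. The fixed-step gradient-descent `solve` and the scalar `tangent_solve` are our own choices for this example, not something jaxopt provides.

```python
import jax
import jax.numpy as jnp


def argmin_latent(param_for_grad, data=6.0):
    # Stationarity condition of 0.5 * (data - param_for_grad * latent)**2, i.e. the
    # latent-gradient of the negative Gaussian log-likelihood (up to a constant).
    # param_for_grad is closed over here, not passed as an argument.
    def optimality(latent):
        return param_for_grad * (param_for_grad * latent - data)

    def solve(f, latent0):
        # Fixed-step gradient descent: latent <- latent - lr * gradient.
        return jax.lax.fori_loop(0, 200, lambda _, l: l - 0.1 * f(l), latent0)

    def tangent_solve(g, y):
        # g is linear in a scalar x, g(x) = a * x, so the solution is y / a.
        return y / g(jnp.ones_like(y))

    return jax.lax.custom_root(optimality, jnp.float32(5.0), solve, tangent_solve)


print(argmin_latent(2.0))            # approximately 3.0
print(jax.grad(argmin_latent)(2.0))  # approximately -1.5
```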
thanks @mblondel - is adding support for closed-over variables something that could be in scope for jaxopt later on? We can certainly provide metadata on which variables require differentiation and which don't.
Could you sketch how this would look like on the user side?
Could we co-opt a `static_args`-like API for this?
Edit: I guess this is equivalent to the `argnums` kwarg... so would that be sufficient?
```python
def pipeline(param_for_grad, data):
    def to_minimize(latent, param_for_grad):
        return -jsp.stats.norm.logpdf(data, loc=param_for_grad*latent, scale=1)
    solver = OptaxSolver(fun=to_minimize, opt=optax.adam(5e-2), implicit_diff=True)
    initial, _ = solver.init(init_params = 5.)
    result, _ = solver.run(initial, param_for_grad)
    return result

# proposed API (hypothetical): mark which arguments of pipeline need derivatives
pipeline = jaxopt.annotate(pipeline, diff_args = (0,))
jax.value_and_grad(pipeline)(2., data=6.)
```
We could support handling closed-over variables via `jax.closure_convert` (like `jax.lax.custom_root`), but the tradeoff is that it requires tracing Python functions to a jaxpr. This means you can't do dynamic control flow/debugging with Python.
The ideal solution would probably be either to (1) encourage users to use `closure_convert` themselves, or (2) possibly add an optional argument for opting into automatic closure conversion, e.g., `closure_convert=True`.
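To make the tracing tradeoff concrete, here is a small sketch using only core JAX; the `objective` is a made-up example. Python-level branching on an argument's value works when the function runs eagerly, but `jax.closure_convert` traces the function with abstract inputs, so the same branch raises a concretization error.

```python
import jax
import jax.numpy as jnp


def objective(latent):
    # Python `if` on the value of latent: fine eagerly, not under tracing.
    if latent > 0:
        return latent ** 2
    return -latent


print(objective(jnp.float32(3.0)))  # 9.0, evaluated eagerly

try:
    jax.closure_convert(objective, jnp.float32(3.0))
except jax.errors.ConcretizationTypeError as err:
    # closure_convert traces objective with an abstract float32 input,
    # so `latent > 0` cannot be converted to a Python bool.
    print("closure_convert failed with", type(err).__name__)
```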
thanks @shoyer - can we use `closure_convert` now to achieve the desired behavior?
Hi again all -- we followed @shoyer's suggestion of using `jax.closure_convert`, with a function pretty much ripped from the `jax` docs:
```python
def _minimize(objective_fn, lhood_pars, lr):
    converted_fn, aux_pars = jax.closure_convert(objective_fn, lhood_pars)
    # aux_pars seems to be empty, took that line from docs example
    solver = OptaxSolver(fun=converted_fn, opt=optax.adam(lr), implicit_diff=True)
    return solver.run(lhood_pars, *aux_pars)[0]
```
where `objective_fn` is usually created on the fly with the aforementioned `setup_objective` function.
Using this does allow autodiff with no errors, but we see a pretty substantial slowdown compared to wrapping an equivalent Adam optimiser in this more explicit implementation of the two-phase method (https://github.com/gehring/fax/blob/1bc0f4ddf4f0d54370b0c04282de49aa85685791/fax/implicit/twophase.py#L79). We were hoping to move away from that implementation to keep up with `jax` releases and other software that also tracks `jax`, and to benefit from the far more active development of `jaxopt`.
As a side note, we also have an additional performance bottleneck from external software constraints that is hard to disentangle from the switch to `jaxopt` (we can no longer JIT some parts of the pipeline because of the `jax` version change), but based on comparisons with JIT removed from the old program, I don't think it comes close to explaining the ~10x slowdown.
Is there any expected drop in performance from using `closure_convert`, perhaps given what I said earlier about the complexity of the `setup_objective` function?
Closure conversion relies on JAX’s jaxpr interpreter, which is much slower than Python’s interpreter. If you can’t JIT the entire thing, that probably explains the performance slowdown.
(Sent by email in reply to Nathan Simpson's comment of Sun, Oct 10, 2021, 4:58 AM.)
I'm not sure what your `setup_objective` function is doing, but I would try to decompose it:
```python
def pipeline(param_for_grad, latent, **kwargs):
    res = intermediary_step(param_for_grad, **kwargs)

    def objective_fun(params, intermediary_result):
        [...]  # do not use param_for_grad or res here!

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, intermediary_result=res * latent).params

jax.jacobian(pipeline)(param_for_grad, latent)
```
The key idea is to use function composition so that the chain rule will apply. You may have to tweak it to your problem but you get the idea.
Any follow up on this? Does your objective seem decomposable in the way I describe?
Having thought about this a bit: I'm not 100% sure this would work, but one way for us to decompose the problem could be to build our statistical model (expensive boilerplate), from which we want to call a `logpdf` method, and then construct the objective like this:
```python
def pipeline(param_for_grad):
    res = model(param_for_grad)  # build the (expensive) statistical model once

    def objective_fun(params, model):
        return -model.logpdf(params)  # minimize the negative log-likelihood

    solver = OptaxSolver(fun=objective_fun, opt=optax.adam(5e-2), implicit_diff=True)
    return solver.run(init_params, model=res).params

jax.jacobian(pipeline)(param_for_grad)
```
Provided this model was registered as a pytree, do you think this would resolve the problem? It's not something we've implemented, but we could if this would work.
Just an update on this: we've managed to get JIT working on our side with `closure_convert`, and we see the performance recover, so @shoyer got it right on that count despite my (incorrect) assumption -- thanks!
If it's helpful, I'd be happy to summarise this thread as a small entry into the documentation via a PR @mblondel, since it could come up again for other users with similar use cases.
Thanks both for the fast and attentive help!
What would the code snippet look like?
hi again @mblondel, sorry to resurrect this from the dead -- my solution that uses `closure_convert` randomly started leaking tracers after one of the `jax`/`jaxlib` updates, and it's a bit of a nightmare to debug. Luckily, I found a fairly simple MWE:
```python
from functools import partial

import jax
import jax.numpy as jnp
import jaxopt
import optax

# dummy model for test purposes
class Model:
    x: jax.Array

    def __init__(self, x) -> None:
        self.x = x

    def logpdf(self, pars, data):
        return jnp.sum(pars*data*self.x)

@partial(jax.jit, static_argnames=["objective_fn"])
def _minimize(
    objective_fn,
    init_pars,
    lr,
):
    # this is the line added from our discussion above
    converted_fn, aux_pars = jax.closure_convert(objective_fn, init_pars)
    # aux_pars seems to be empty -- would have assumed it was the closed-over vals or similar?
    solver = jaxopt.OptaxSolver(
        fun=converted_fn, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
    )
    return solver.run(init_pars, *aux_pars)[0]

@partial(jax.jit, static_argnames=["model"])
def fit(
    data,
    model,
    init_pars,
    lr = 1e-3,
):
    def fit_objective(pars):
        return -model.logpdf(pars, data)
    fit_res = _minimize(fit_objective, init_pars, lr)
    return fit_res

def pipeline(x):
    model = Model(x)
    mle_pars = fit(
        model=model,
        data=jnp.array([5.0, 5.0]),
        init_pars=jnp.array([1.0, 1.1]),
        lr=1e-3,
    )
    return mle_pars

jax.jacrev(pipeline)(jnp.asarray(0.5))
# >> JaxStackTraceBeforeTransformation
```
(this is `jaxopt==0.6`)
Another thing to note: the jaxpr tracing induced by `closure_convert` seems to really fill up the cache, which made this quite problematic in practice (I had to use @patrick-kidger's hack from this JAX issue). Just a health warning for anyone else interested in this type of solution!
I can't see an immediate way, but if we could cast this example into the decomposed form you suggested above, that would be the best way around this issue (i.e. avoiding `closure_convert` altogether).
I explored this a bit, and my particular workflow here became possible by making the `Model` class a Pytree, which lets me feed the model to the objective function as an explicit argument while keeping the whole optimization procedure under JIT. I think this also means that the relevant parameters for grad are no longer closed over, since the Pytree carries them.
```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jaxopt
import optax
from simple_pytree import Pytree

# dummy model for test purposes
class Model(Pytree):
    x: jax.Array

    def __init__(self, x) -> None:
        self.x = x

    def logpdf(self, pars, data):
        return jsp.stats.norm.logpdf(data, loc=pars*self.x, scale=1.0).sum()

@jax.jit
def pipeline(param_for_grad):
    data=jnp.array([5.0, 5.0])
    init_pars=jnp.array([1.0, 1.1])
    lr=1e-3
    model = Model(param_for_grad)

    def fit(pars, model, data):
        def fit_objective(pars, model, data):
            return -model.logpdf(pars, data)
        solver = jaxopt.OptaxSolver(
            fun=fit_objective, opt=optax.adam(lr), implicit_diff=True, maxiter=5000
        )
        return solver.run(init_pars, model=model, data=data)[0]

    return fit(init_pars, model, data)

jax.jacrev(pipeline)(jnp.asarray(0.5))
# > Array([-1.33830826e+01,  7.10542736e-15], dtype=float64, weak_type=True)
```
Don't know if there's another potential issue above that I'm smearing over with this approach, but it works without `closure_convert`! It may be hard to coerce a complicated model into a Pytree, but that's possibly something for us to worry more about.
I think you're doing the right thing by making the model a PyTree, i.e. I don't think you're smearing over any issue.
This is the same approach Equinox uses ubiquitously, and this handles all the complexity of Diffrax just fine!