Automatically check for `jit=True` when a solver is vmapped?
I'm not sure whether this is possible, but I found out the hard way that vmapping a solver strictly requires its optimization loop to be staged, i.e. the solver must be constructed with `jit=True`.
For example, this code works fine:
```python
import jax
import jax.numpy as jnp
import jaxopt

def fun(x, y):
    return jnp.square(x - y) + y

solver = jaxopt.LBFGS(fun=fun, maxiter=5, jit=True)
params, state = solver.run(jax.device_put(0.1), y=jax.device_put(2.1))
# Batch over 12 initial points; the keyword arguments in `k` are shared (in_axes=None).
batch_params, batch_state = jax.vmap(lambda a, k: solver.run(a, **k), in_axes=(0, None))(
    jnp.repeat(0.1, 12), {'y': jax.device_put(2.1)})
# No errors raised
```
However, this code raises a `ConcretizationTypeError` because the solver is constructed with `jit=False`:
```python
import jax
import jax.numpy as jnp
import jaxopt

def fun(x, y):
    return jnp.square(x - y) + y

# Same setup as above, except that the solver loop is not staged (jit=False).
solver = jaxopt.LBFGS(fun=fun, maxiter=5, jit=False)
params, state = solver.run(jax.device_put(0.1), y=jax.device_put(2.1))  # Works fine
batch_params, batch_state = jax.vmap(lambda a, k: solver.run(a, **k), in_axes=(0, None))(
    jnp.repeat(0.1, 12), {'y': jax.device_put(2.1)})  # ConcretizationTypeError due to vmap
```
The exception is raised by the condition function (`cond_fun`) of the while loop, so on inspection I understand why this happens. However, the solver's signature made me believe that jitting was optional, whereas under `vmap` you have to jit.
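The mechanism can be reproduced without jaxopt. The sketch below assumes that `jit=False` amounts to running the optimization loop as a plain Python `while`, which calls `bool()` on the loop condition, whereas `jit=True` stages it with `jax.lax.while_loop`. The functions `python_loop` and `staged_loop` and the toy "halve until below 1" problem are hypothetical stand-ins for the solver, not jaxopt code.

```python
import jax
import jax.numpy as jnp
from jax import lax

def cond(state):
    # Data-dependent stopping criterion, like a solver's tolerance check.
    _, x = state
    return x > 1.0

def body(state):
    i, x = state
    return i + 1, x / 2.0

def python_loop(x0):
    # Assumed analogue of jit=False: Python needs a concrete bool from cond().
    state = (0, x0)
    while cond(state):
        state = body(state)
    return state[0]

def staged_loop(x0):
    # Assumed analogue of jit=True: lax.while_loop accepts a traced predicate,
    # and its batching rule keeps iterating until every batch element is done.
    i, _ = lax.while_loop(cond, body, (0, x0))
    return i

xs = jnp.array([0.5, 3.0, 10.0])
print(jax.vmap(staged_loop)(xs))  # [0 2 4]

try:
    jax.vmap(python_loop)(xs)  # cond() returns a batched tracer here
except jax.errors.ConcretizationTypeError as err:
    print("Python loop under vmap fails with", type(err).__name__)
```

Under `vmap`, each example may need a different number of iterations, so the condition is a batched tracer with no single Python truth value; only the staged loop can express "keep going until all elements have converged".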
Could this perhaps be inferred automatically? Or is this unfortunately just a sharp edge of JAX? In that case, I think it would be useful to note this in the documentation of the `jit` argument of the solver classes.
Hi @joeryjoery. Indeed, `jit=True` is necessary for a solver to be "vmappable". We can definitely improve the documentation about this, but I am not sure whether it's possible to detect this case and raise an error message. @froystig will definitely know more.
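On the question of detection, one possible approach is for the non-staged Python loop to catch the `ConcretizationTypeError` raised when the traced loop condition is converted to `bool`, and then re-raise it with a hint. This is a sketch under assumptions: `while_loop_with_hint` and `count_halvings` are hypothetical names, and jaxopt's actual loop implementation may differ.

```python
import jax
import jax.numpy as jnp
from jax import lax

def while_loop_with_hint(cond_fun, body_fun, init_val, jit):
    """Hypothetical helper: staged loop if `jit`, otherwise a Python loop."""
    if jit:
        # Staged loop: the predicate may be traced, so vmap and jit both work.
        return lax.while_loop(cond_fun, body_fun, init_val)
    val = init_val
    while True:
        pred = cond_fun(val)
        try:
            # Only the bool conversion is guarded, so errors raised by
            # body_fun are not misreported as a missing jit=True.
            keep_going = bool(pred)
        except jax.errors.ConcretizationTypeError as err:
            raise ValueError(
                "The loop condition is a traced value (is the solver being "
                "vmapped or jitted?); construct the solver with jit=True."
            ) from err
        if not keep_going:
            return val
        val = body_fun(val)

def count_halvings(x0, jit):
    # Toy stand-in for a solver: count halvings until x drops to 1 or below.
    cond = lambda s: s[1] > 1.0
    body = lambda s: (s[0] + 1, s[1] / 2.0)
    return while_loop_with_hint(cond, body, (0, x0), jit)[0]

xs = jnp.array([0.5, 3.0, 10.0])
print(jax.vmap(lambda x: count_halvings(x, jit=True))(xs))  # [0 2 4]
try:
    jax.vmap(lambda x: count_halvings(x, jit=False))(xs)
except ValueError as err:
    print(err)
```

Instead of raising, such a helper could also fall back to `lax.while_loop` on the first traced condition, which would amount to inferring `jit=True` automatically; whether that silent switch is desirable is a design question this sketch does not settle.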