Using `vmap` for root finding with a vector of parameters
Hi devs,
I'm trying to vmap over a root-finding procedure involving vector-valued parameters. I gather this is currently not possible with jaxopt.ScipyRootFinding because it is a wrapper around scipy and raises a TracerArrayConversionError. Do you have any suggestions for batching such an optimization procedure? Is the only recourse to write a custom root-finding routine in JAX? I should note I don't necessarily need gradients of the root finder, just the batching ability.
Hello! Could you send a minimal code example? I'm happy to investigate that. Which method are you using?
This comes from the fact that ScipyMinimize and ScipyRootFinding are not written in pure JAX. We need to wrap the call to scipy.optimize.minimize with a jax.lax.pure_callback (see #372) as was done for instance here. This will make ScipyMinimize and ScipyRootFinding work with vmap and jit.
Hi, I'm running into two different errors, with and w/o vmap. Hopefully they're related. Here's a MWE:
- without vmap
import jax, jaxopt
import jax.numpy as jnp
import numpy as np
def opt_fun(x, b):
    return jnp.linalg.norm(x - b)
x = np.random.randn(3)
b = np.random.randn(3)
root_finder = jaxopt.ScipyRootFinding('hybr', optimality_fun=opt_fun)
root_finder.run(x,b)
# ValueError: cannot reshape array of size 6 into shape (6,6)
- with vmap
x = np.random.randn(2,3)
b = np.random.randn(2,3)
root_finder = jaxopt.ScipyRootFinding('hybr', optimality_fun=opt_fun)
jax.vmap(root_finder.run, in_axes=(0,0))(x, b)
# TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3]...
Thanks, could you elaborate slightly on the pure_callback suggestion? Is jax.pure_callback supposed to wrap the optimality function? Doing so seems to give me an XlaRuntimeError.
Thanks! For the first error: to find a root of a function, the function must map your input space to itself, so its output has the same shape as its input. Right now you are using a function that maps your input space to a real number.
Typically, you should have something of the form
import jax
import jax.numpy as jnp
import numpy as np
from jaxopt import ScipyRootFinding
def opt_fun(x, b):
    return x-b
x = np.random.randn(3)
b = np.random.randn(3)
root_finder = ScipyRootFinding('hybr', optimality_fun=opt_fun)
root_finder.run(x, b)
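A quick way to check this before calling the solver is to compare output shapes. The sketch below is only an illustration: the names `scalar_residual` and `vector_residual` are made up for this example, and it uses `jax.eval_shape` so that nothing is actually computed.

```python
import jax
import jax.numpy as jnp

def scalar_residual(x, b):
    # Maps R^3 -> R: fine for minimization, not a valid root-finding residual here.
    return jnp.linalg.norm(x - b)

def vector_residual(x, b):
    # Maps R^3 -> R^3: same shape in and out, as a root finder expects.
    return x - b

x = jnp.zeros(3)
b = jnp.ones(3)

# eval_shape traces the function abstractly and returns shape/dtype info only.
scalar_shape = jax.eval_shape(scalar_residual, x, b).shape
vector_shape = jax.eval_shape(vector_residual, x, b).shape

print(scalar_shape)  # ()
print(vector_shape)  # (3,)
```

If the output shape differs from the shape of `x`, the problem is better posed as a minimization than as a root-finding problem.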
@Justin-Tan The modifications are on our side. We need to wrap the call to scipy.optimize.minimize or scipy.optimize.root with a pure_callback.
Thanks both, any pointers on how I can batch a problem where I have multiple root-finding problems involving a vector of parameters in JAX? I've tried treating it as a minimization problem and vmapping over the standard jax.scipy.optimize.minimize and the unconstrained minimization solvers in jaxopt, but it's still rather slow.
I managed to get it to work by calling scipy.optimize.root through jax.pure_callback directly (code below), but it's still rather slow. I'll try pmapping it over cores to see if that helps. I couldn't get it to work with jaxopt's ScipyWrapper (https://jaxopt.github.io/stable/_modules/jaxopt/_src/scipy_wrappers.html#ScipyRootFinding.run) because of issues with the computation of the Jacobian via implicit diff, I think.
from jax import config
config.update("jax_enable_x64", True)
import jax, jaxopt
from jax import vmap, jit
import jax.numpy as jnp
import numpy as np
import scipy.optimize as so
def opt_fun(x, b):
    return x-b
jac_opt_fun = jax.jacrev(opt_fun)
def scipy_fun(x_onp, *args, **kwargs):
    # del scipy_args
    x_jnp = jnp.asarray(x_onp)
    value_jnp = opt_fun(x_jnp, *args, **kwargs)
    jacs_jnp = jac_opt_fun(x_jnp, *args, **kwargs)
    return np.asarray(value_jnp), np.asarray(jacs_jnp)
def scipy_opt_root(init_params, opt_args):
    res = so.root(scipy_fun,
                  init_params,
                  args=opt_args,
                  method=method,
                  jac=True)
    return res.x
@jit
def find_root(init_params, opt_args):
    shape_dtype = jax.ShapeDtypeStruct(shape=init_params.shape,
                                       dtype=init_params.dtype)
    res = jax.pure_callback(scipy_opt_root,
                            shape_dtype,
                            init_params,
                            opt_args,
                            vectorized=False)
    return res
method = 'hybr'
x = np.random.randn(2,3)
b = np.random.randn(2,3)
vmap(find_root)(x, b)
Hello Justin! Yes, the issue is that pure_callback requires all the code inside it to operate on numpy arrays. But since the function is supposed to be written in JAX, we are converting it before calling the scipy function. I tried this in https://github.com/google/jaxopt/pull/494 (see the sandbox.py file; you probably ran into a similar problem). I don't know yet how to circumvent this, or whether having our own implementation of a few root-finding algorithms would be a better option.
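For reference, here is a rough sketch of what the "own implementation" option could look like if batching is the only requirement. It is a plain Newton iteration written in pure JAX, not the Powell hybrid method behind 'hybr', and it has no line search or convergence check. The helper `newton_root` and its `num_steps` argument are hypothetical names for this example, not part of jaxopt.

```python
import jax
import jax.numpy as jnp

def opt_fun(x, b):
    # Same residual as in the thread; its root is x = b.
    return x - b

def newton_root(fun, x0, *args, num_steps=20):
    # Hypothetical helper: fixed number of undamped Newton steps.
    jac = jax.jacfwd(fun)  # Jacobian with respect to the first argument

    def step(_, x):
        # Solve J(x) dx = -F(x) and take the full step.
        dx = jnp.linalg.solve(jac(x, *args), -fun(x, *args))
        return x + dx

    return jax.lax.fori_loop(0, num_steps, step, x0)

# Everything is traceable, so vmap and jit compose without callbacks.
batched_root = jax.jit(jax.vmap(lambda x0, b: newton_root(opt_fun, x0, b)))

key_x, key_b = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (2, 3))
b = jax.random.normal(key_b, (2, 3))
roots = batched_root(x, b)
print(roots.shape)
```

Because every batch element runs the same fixed number of steps, this avoids the sequential host round trips of the pure_callback approach, at the cost of the robustness that scipy's solvers provide for harder problems.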