Using `vmap` for root finding with a vector of parameters

Open Justin-Tan opened this issue 2 years ago • 8 comments

Hi devs,

I'm trying to vmap over a root-finding procedure involving vector-valued parameters. I gather this is currently not possible with jaxopt.ScipyRootFinding because it is a wrapper around SciPy and throws a TracerArrayConversionError. Do you have any suggestions for batching such an optimization procedure? Is the only recourse to write a custom root-finding routine in JAX? I should note I don't necessarily need gradients of the root finder, just the batching ability.

Justin-Tan avatar Aug 03 '23 04:08 Justin-Tan

Hello! Could you send a minimal code example? I'm happy to investigate that. Which method are you using?

vroulet avatar Aug 03 '23 07:08 vroulet

This comes from the fact that ScipyMinimize and ScipyRootFinding are not written in pure JAX. We need to wrap the call to scipy.optimize.minimize with a jax.lax.pure_callback (see #372) as was done for instance here. This will make ScipyMinimize and ScipyRootFinding work with vmap and jit.

mblondel avatar Aug 03 '23 09:08 mblondel

Hello! Could you send a minimal code example? I'm happy to investigate that. Which method are you using?

Hi, I'm running into two different errors, with and w/o vmap. Hopefully they're related. Here's a MWE:

  • without vmap
import jax, jaxopt
import jax.numpy as jnp
import numpy as np

def opt_fun(x, b):
    return jnp.linalg.norm(x - b)

x = np.random.randn(3)
b = np.random.randn(3)

root_finder  = jaxopt.ScipyRootFinding('hybr', optimality_fun=opt_fun)
root_finder.run(x,b)
# ValueError: cannot reshape array of size 6 into shape (6,6)
  • with vmap
x = np.random.randn(2,3)
b = np.random.randn(2,3)

root_finder  = jaxopt.ScipyRootFinding('hybr', optimality_fun=opt_fun)
jax.vmap(root_finder.run, in_axes=(0,0))(x, b)
# TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3]...

This comes from the fact that ScipyMinimize and ScipyRootFinding are not written in pure JAX. We need to wrap the call to scipy.optimize.minimize with a jax.lax.pure_callback (see #372) as was done for instance here. This will make ScipyMinimize and ScipyRootFinding work with vmap and jit.

Thanks, could you elaborate slightly here? Is jax.pure_callback supposed to wrap the optimality function? Doing so seems to give me an XlaRuntimeError.

Justin-Tan avatar Aug 04 '23 05:08 Justin-Tan

Thanks! The first error arises because, to find a root of a function, the function must map your input space to itself. Right now you are using a function that maps your input space to a real number.

Typically, you should have something of the form

import jax
import jax.numpy as jnp
import numpy as np
from jaxopt import ScipyRootFinding

def opt_fun(x, b):
    return x-b

x = np.random.randn(3)
b = np.random.randn(3)

root_finder = ScipyRootFinding('hybr', optimality_fun=opt_fun)
root_finder.run(x, b)
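
One quick way to check this condition before calling a solver is to compare the output shape of the function with the shape of its input. The sketch below uses jax.eval_shape for that; the helper name maps_to_input_space is made up for illustration and is not part of jaxopt.

```python
import jax
import jax.numpy as jnp

def scalar_fun(x, b):
    # Maps R^3 to a real number: not a valid root-finding residual.
    return jnp.linalg.norm(x - b)

def vector_fun(x, b):
    # Maps R^3 to R^3: output has the same shape as the input x.
    return x - b

def maps_to_input_space(fun, x, *args):
    # Hypothetical helper: evaluates only shapes, no actual computation.
    out = jax.eval_shape(fun, x, *args)
    return out.shape == jnp.shape(x)

x = jnp.zeros(3)
b = jnp.ones(3)
scalar_ok = maps_to_input_space(scalar_fun, x, b)
vector_ok = maps_to_input_space(vector_fun, x, b)
print(scalar_ok, vector_ok)
```

Here the norm-based function fails the check because its output is a scalar, while x - b passes.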

vroulet avatar Aug 04 '23 08:08 vroulet

@Justin-Tan The modifications are on our side. We need to wrap the call to scipy.optimize.minimize or scipy.optimize.root with a pure_callback.
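
As a rough user-side sketch of the same idea (this is not the jaxopt change itself), scipy.optimize.root can be called from inside a jax.pure_callback. The version below assumes the residual can be written in plain NumPy and loops over the batch on the host, so no vmap of the callback is needed. The names residual_np, batched_root_host and find_roots are made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize as so

def residual_np(x, b):
    # NumPy-only residual (illustrative): the root is x = b.
    return x - b

def batched_root_host(x0_batch, b_batch):
    # Runs on the host; convert inputs to NumPy arrays explicitly.
    x0_batch = np.asarray(x0_batch)
    b_batch = np.asarray(b_batch)
    out = np.empty_like(x0_batch)
    # Solve one root-finding problem per row of the batch.
    for i in range(x0_batch.shape[0]):
        sol = so.root(residual_np, x0_batch[i], args=(b_batch[i],), method="hybr")
        out[i] = sol.x  # cast back to the input dtype on assignment
    return out

@jax.jit
def find_roots(x0_batch, b_batch):
    # Tell JAX the shape and dtype of the callback result.
    out_spec = jax.ShapeDtypeStruct(x0_batch.shape, x0_batch.dtype)
    return jax.pure_callback(batched_root_host, out_spec, x0_batch, b_batch)

rng = np.random.default_rng(0)
x0 = jnp.asarray(rng.standard_normal((2, 3)), dtype=jnp.float32)
b = jnp.asarray(rng.standard_normal((2, 3)), dtype=jnp.float32)
roots = find_roots(x0, b)
```

The trade-off is that the batch is solved sequentially on the host, so this gives jit compatibility but no speedup from batching.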

mblondel avatar Aug 04 '23 09:08 mblondel

Thanks both, any pointers on how I can batch a problem where I have multiple root-finding problems involving a vector of parameters in JAX? I've tried treating it as a minimization problem and vmapping using the standard jax.scipy.optimize.minimize and the unconstrained minimization solvers in jaxopt, but it's still rather slow.

Justin-Tan avatar Aug 04 '23 11:08 Justin-Tan

I managed to get it to work, but it's still rather slow. I'll try pmapping it over cores to see if that helps. I couldn't get it to work with jaxopt's ScipyWrapper (https://jaxopt.github.io/stable/_modules/jaxopt/_src/scipy_wrappers.html#ScipyRootFinding.run) because of issues with computing the Jacobian via implicit differentiation, I think.

from jax import config
config.update("jax_enable_x64", True)
import jax, jaxopt
from jax import vmap, jit
import jax.numpy as jnp
import numpy as np

import scipy.optimize as so

def opt_fun(x, b):
    return x-b
jac_opt_fun = jax.jacrev(opt_fun)

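def scipy_fun(x_onp, *args, **kwargs):
    # SciPy passes NumPy arrays; convert to JAX before evaluating.
    x_jnp = jnp.asarray(x_onp)
    value_jnp = opt_fun(x_jnp, *args, **kwargs)
    jacs_jnp = jac_opt_fun(x_jnp, *args, **kwargs)
    # Return (residual, Jacobian) as NumPy arrays, as expected with jac=True.
    return np.asarray(value_jnp), np.asarray(jacs_jnp)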

def scipy_opt_root(init_params, opt_args):
    res = so.root(scipy_fun,
                    init_params,
                    args=opt_args,
                    method=method,
                    jac=True)
    return res.x

@jit
def find_root(init_params, opt_args):
    shape_dtype = jax.ShapeDtypeStruct(shape=init_params.shape, 
                                    dtype=init_params.dtype)
    res = jax.pure_callback(scipy_opt_root, 
                            shape_dtype,
                            init_params,
                            opt_args,
                            vectorized=False)
    return res

method = 'hybr'
x = np.random.randn(2,3)
b = np.random.randn(2,3)
vmap(find_root)(x, b)

Justin-Tan avatar Aug 05 '23 23:08 Justin-Tan

Hello Justin! Yes, the issue is that pure_callback requires all the code inside the callback to be NumPy code. But since the optimality function is supposed to be written in JAX, we convert it before calling the SciPy function. I tried this in https://github.com/google/jaxopt/pull/494 (see the sandbox.py file; you probably ran into a similar problem). I don't yet know how to get around this, or whether having our own implementation of a few root-finding algorithms would be a better option.
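
To give an idea of what a pure-JAX alternative could look like, here is a minimal sketch of a fixed-step Newton iteration that works under vmap and jit. It is only an illustration, not an existing jaxopt solver: newton_root and the nonlinear residual are made-up names, there is no line search or convergence check, and the number of steps is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def residual(x, b):
    # Illustrative nonlinear residual from R^n to R^n (not from jaxopt).
    return x**3 + x - b

def newton_root(fun, x0, *args, num_steps=30):
    # Plain Newton iteration with a fixed number of steps.
    jac = jax.jacfwd(fun)  # Jacobian with respect to the first argument

    def step(_, x):
        # Solve J(x) dx = f(x) and update x <- x - dx.
        return x - jnp.linalg.solve(jac(x, *args), fun(x, *args))

    return jax.lax.fori_loop(0, num_steps, step, x0)

# Batch over the leading axis of both the initial guess and the parameters.
batched_newton = jax.jit(jax.vmap(lambda x0, b: newton_root(residual, x0, b)))

x0 = jax.random.normal(jax.random.PRNGKey(0), (2, 3))
b = jax.random.normal(jax.random.PRNGKey(1), (2, 3))
roots = batched_newton(x0, b)
```

Because everything stays in JAX, the whole batch is solved in one compiled call. For harder problems a damped step or a stopping criterion via jax.lax.while_loop would likely be needed.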

vroulet avatar Aug 07 '23 16:08 vroulet