jaxopt icon indicating copy to clipboard operation
jaxopt copied to clipboard

ScipyBoundedMinimize & jit => error

Open jecampagne opened this issue 3 years ago • 2 comments

Hello. Below is a snippet that crashes when jit is applied to a function that uses ScipyBoundedMinimize. I am using jax 0.3.5 and jaxopt 0.3.1.

Here is the snippet

# Sample abscissae (16 points); passed as the `t` argument of `func`/`lik` below.
t=np.array([-1.26027397, -0.79178082, -0.21643836,  0.13424658, 0.4       ,
        0.63013699,  0.87945205,  1.28219178,  1.70410959, 2.17260274,
        2.69041096,  3.18356164,  3.68767123,  4.18082192, 4.69863014,
        5.17808219])

# Observed values (16 points); `lik` subtracts them from the model output `func(p, t)`.
R=np.array([17.69069002, 17.89783266, 18.03628099, 17.32367147, 16.54036741,
       15.60960639, 15.32250416, 14.78162801, 14.69544146, 14.84802999,
       14.52303434, 14.30856416, 13.54537271, 12.94627482, 13.5084652 ,
       13.69047547])
@jit
def func(p, t):
    """Evaluate the piecewise model at ``t`` for the parameters ``p``.

    For ``t <= 0`` the model is the straight line ``p[0] + p[1] * t``.
    For ``t > 0`` the term ``p[2] * (1 - exp(-t / p[3]))`` is subtracted
    from that same line.

    Args:
        p: Indexable holding the four model parameters ``p[0]``..``p[3]``.
        t: Point(s) at which the model is evaluated.

    Returns:
        The model value(s), selected element-wise with ``jnp.where``.
    """
    # Both branches share the same linear part, so name it once.
    baseline = p[0] + p[1] * t
    saturation = p[2] * (1 - jnp.exp(-t / p[3]))
    return jnp.where(t <= 0, baseline, baseline - saturation)
    
@jit
def lik(p, t, R):
    """Return half the sum of squared residuals between the model and ``R``.

    Args:
        p: Model parameters forwarded to ``func``.
        t: Points at which ``func`` is evaluated.
        R: Reference values subtracted from the model output.

    Returns:
        ``0.5 * sum((func(p, t) - R) ** 2)``.
    """
    deviation = func(p, t) - R
    sum_of_squares = jnp.sum(deviation ** 2)
    return 0.5 * sum_of_squares

def get_infos(res, model, t, R):
    """Collect the solution and derivative information of ``model`` at it.

    Args:
        res: Solver result exposing ``params`` and ``state.fun_val``.
        model: Callable invoked as ``model(params, t, R)``; it is
            differentiated with respect to its first argument.
        t: Second argument forwarded to ``model``.
        R: Third argument forwarded to ``model``.

    Returns:
        A 4-tuple ``(params, fun_val, jacobian, inverse_hessian)`` where the
        Jacobian (forward mode) and the Hessian are evaluated at ``params``.
    """
    solution = res.params
    value_at_solution = res.state.fun_val
    jacobian = jax.jacfwd(model)(solution, t, R)
    hessian = jax.hessian(model)(solution, t, R)
    return solution, value_at_solution, jacobian, jax.scipy.linalg.inv(hessian)

# NOTE: this decorator is what triggers the reported failure. The report states
# that the function returns the expected OptStep once it is commented out.
@jit
def test():
    """Minimize `lik` over the module-level data `t`/`R` with a bounded SciPy solver."""
    # The variable is named `lbfgsb`, but the method actually requested is SLSQP.
    lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik, method="SLSQP")
    init_params = jnp.array([18.,2.,10.,1.])
    # bounds = (lower, upper): only the second parameter is constrained (>= 0);
    # every other bound is infinite.
    res = lbfgsb.run(init_params, bounds=([-jnp.inf,0,-jnp.inf,-jnp.inf],[jnp.inf]*4), 
                 t=t, R=R)
    return res

# Calling the jit-decorated `test` is the statement that raises (see the traceback below).
res = test()
# Bare expression so the notebook cell displays the result.
res 

If I comment out the jit decorator on the test function, I get the result I expect:

OptStep(params=DeviceArray([1.77792760e+01, 1.05476761e-16, 4.19824962e+00,
             1.17624411e+00], dtype=float64), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.93453007, dtype=float64, weak_type=True), success=True, status=0, iter_num=17))

Now, if I activate the jit decorator, I get this error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [12], in <cell line: 1>()
----> 1 res = test()
      2 res

    [... skipping hidden 14 frame]

Input In [11], in test()
      3 lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik, method="SLSQP")
      4 init_params = jnp.array([18.,2.,10.,1.])
----> 5 res = lbfgsb.run(init_params, bounds=([-jnp.inf,0,-jnp.inf,-jnp.inf],[jnp.inf]*4), 
      6              t=t, R=R)
      7 return res

File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
    249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
    250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)

    [... skipping hidden 6 frame]

File jaxOptim/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
    204 @jax.custom_vjp
    205 def solver_fun_flat(*flat_args):
    206   args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207   return solver_fun(*args, **kwargs)

File jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:378, in ScipyBoundedMinimize.run(self, init_params, bounds, *args, **kwargs)
    362 def run(self,
    363         init_params: Any,
    364         bounds: Optional[Any],
    365         *args,
    366         **kwargs) -> base.OptStep:
    367   """Runs the solver.
    368 
    369   Args:
   (...)
    376     (params, info).
    377   """
--> 378   return self._run(init_params, bounds, *args, **kwargs)

File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:287, in ScipyMinimize._run(self, init_params, bounds, *args, **kwargs)
    284   return onp.asarray(value, self.dtype), jnp_to_onp(grads, self.dtype)
    286 if bounds is not None:
--> 287   bounds = osp.optimize.Bounds(lb=jnp_to_onp(bounds[0], self.dtype),
    288                                ub=jnp_to_onp(bounds[1], self.dtype))
    290 res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
    291                             jac=True,
    292                             bounds=bounds,
    293                             method=self.method,
    294                             options=self.options)
    296 params = tree_util.tree_map(jnp.asarray, onp_to_jnp(res.x))

File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:115, in jnp_to_onp(x_jnp, dtype)
     98 def jnp_to_onp(x_jnp: Any,
     99                dtype: Optional[Any] = onp.float64) -> onp.ndarray:
    100   """Converts JAX PyTree into repr suitable for scipy.optimize.minimize.
    101 
    102   Several of SciPy's optimization routines require inputs and/or outputs to be
   (...)
    113     determined by NumPy's casting rules for the concatenate method.
    114   """
--> 115   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    116            for leaf in tree_util.tree_leaves(x_jnp)]
    117   # NOTE(fllinares): return value must *not* be read-only, I believe.
    118   return onp.concatenate(x_onp)

File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:115, in <listcomp>(.0)
     98 def jnp_to_onp(x_jnp: Any,
     99                dtype: Optional[Any] = onp.float64) -> onp.ndarray:
    100   """Converts JAX PyTree into repr suitable for scipy.optimize.minimize.
    101 
    102   Several of SciPy's optimization routines require inputs and/or outputs to be
   (...)
    113     determined by NumPy's casting rules for the concatenate method.
    114   """
--> 115   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    116            for leaf in tree_util.tree_leaves(x_jnp)]
    117   # NOTE(fllinares): return value must *not* be read-only, I believe.
    118   return onp.concatenate(x_onp)

    [... skipping hidden 1 frame]

File /jaxOptim/lib/python3.8/site-packages/jax/_src/errors.py:321, in TracerArrayConversionError.__init__(self, tracer)
    319 def __init__(self, tracer: "core.Tracer"):
    320   super().__init__(
--> 321       "The numpy.ndarray conversion method __array__() was called on "
    322       f"the JAX Tracer object {tracer}{tracer._origin_msg()}")

    [... skipping hidden 1 frame]

File /jaxOptim/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1702, in arg_info_pytree(fn, in_tree, has_kwargs, flat_pos)
   1699 def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
   1700                     flat_pos: List[int]) -> str:
   1701   dummy_args = [False] * in_tree.num_leaves
-> 1702   for i in flat_pos: dummy_args[i] = True
   1703   if has_kwargs:
   1704     args, kwargs = tree_unflatten(in_tree, dummy_args)

IndexError: list assignment index out of range

Do you have any idea whether I can use jit in this case and, if so, how? Thanks.

jecampagne avatar Apr 30 '22 08:04 jecampagne

Thanks for the report! This is a recurring issue. Because ScipyBoundedMinimize calls into NumPy code (i.e., non-native JAX code), it cannot be jitted. We would need to either wrap the internal code with a host_callback or raise an error when we detect that the function is being jitted.

mblondel avatar May 02 '22 06:05 mblondel

Hi @mblondel you may be interested in the discussion jax/jit

By the way, what is the difference from the jax.scipy BFGS solver? Is it the support for bounds?

jecampagne avatar May 10 '22 14:05 jecampagne