ScipyBoundedMinimize & jit => error
Hello, Below I give a snippet that crashes when jit is activated on a function that use ScipyBoundedMinimize I am using jax 0.3.5 jaxopt 0.3.1
Here is the snippet
t=np.array([-1.26027397, -0.79178082, -0.21643836, 0.13424658, 0.4 ,
0.63013699, 0.87945205, 1.28219178, 1.70410959, 2.17260274,
2.69041096, 3.18356164, 3.68767123, 4.18082192, 4.69863014,
5.17808219])
R=np.array([17.69069002, 17.89783266, 18.03628099, 17.32367147, 16.54036741,
15.60960639, 15.32250416, 14.78162801, 14.69544146, 14.84802999,
14.52303434, 14.30856416, 13.54537271, 12.94627482, 13.5084652 ,
13.69047547])
@jit
def func(p,t):
return jnp.where(t<=0,p[0]+p[1]*t,p[0]+p[1]*t-p[2]*(1-jnp.exp(-t/p[3])))
@jit
def lik(p,t,R):
resid = func(p,t)-R
return 0.5*jnp.sum(resid ** 2)
def get_infos(res, model, t,R):
params = res.params
fun_min= res.state.fun_val
jacob_min=jax.jacfwd(model)(params, t,R)
inv_hessian_min=jax.scipy.linalg.inv(jax.hessian(model)(params, t,R))
return params,fun_min,jacob_min,inv_hessian_min
@jit
def test():
lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik, method="SLSQP")
init_params = jnp.array([18.,2.,10.,1.])
res = lbfgsb.run(init_params, bounds=([-jnp.inf,0,-jnp.inf,-jnp.inf],[jnp.inf]*4),
t=t, R=R)
return res
res = test()
res
Then if I comment the jit-decorator of the test function I get what I am expecting
OptStep(params=DeviceArray([1.77792760e+01, 1.05476761e-16, 4.19824962e+00,
1.17624411e+00], dtype=float64), state=ScipyMinimizeInfo(fun_val=DeviceArray(0.93453007, dtype=float64, weak_type=True), success=True, status=0, iter_num=17))
Now, If I activate the jit-decorator, I get this error
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Input In [12], in <cell line: 1>()
----> 1 res = test()
2 res
[... skipping hidden 14 frame]
Input In [11], in test()
3 lbfgsb = jaxopt.ScipyBoundedMinimize(fun=lik, method="SLSQP")
4 init_params = jnp.array([18.,2.,10.,1.])
----> 5 res = lbfgsb.run(init_params, bounds=([-jnp.inf,0,-jnp.inf,-jnp.inf],[jnp.inf]*4),
6 t=t, R=R)
7 return res
File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py:251, in _custom_root.<locals>.wrapped_solver_fun(*args, **kwargs)
249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs)
250 keys, vals = list(kwargs.keys()), list(kwargs.values())
--> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals)
[... skipping hidden 6 frame]
File jaxOptim/lib/python3.8/site-packages/jaxopt/_src/implicit_diff.py:207, in _custom_root.<locals>.make_custom_vjp_solver_fun.<locals>.solver_fun_flat(*flat_args)
204 @jax.custom_vjp
205 def solver_fun_flat(*flat_args):
206 args, kwargs = _extract_kwargs(kwarg_keys, flat_args)
--> 207 return solver_fun(*args, **kwargs)
File jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:378, in ScipyBoundedMinimize.run(self, init_params, bounds, *args, **kwargs)
362 def run(self,
363 init_params: Any,
364 bounds: Optional[Any],
365 *args,
366 **kwargs) -> base.OptStep:
367 """Runs the solver.
368
369 Args:
(...)
376 (params, info).
377 """
--> 378 return self._run(init_params, bounds, *args, **kwargs)
File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:287, in ScipyMinimize._run(self, init_params, bounds, *args, **kwargs)
284 return onp.asarray(value, self.dtype), jnp_to_onp(grads, self.dtype)
286 if bounds is not None:
--> 287 bounds = osp.optimize.Bounds(lb=jnp_to_onp(bounds[0], self.dtype),
288 ub=jnp_to_onp(bounds[1], self.dtype))
290 res = osp.optimize.minimize(scipy_fun, jnp_to_onp(init_params, self.dtype),
291 jac=True,
292 bounds=bounds,
293 method=self.method,
294 options=self.options)
296 params = tree_util.tree_map(jnp.asarray, onp_to_jnp(res.x))
File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:115, in jnp_to_onp(x_jnp, dtype)
98 def jnp_to_onp(x_jnp: Any,
99 dtype: Optional[Any] = onp.float64) -> onp.ndarray:
100 """Converts JAX PyTree into repr suitable for scipy.optimize.minimize.
101
102 Several of SciPy's optimization routines require inputs and/or outputs to be
(...)
113 determined by NumPy's casting rules for the concatenate method.
114 """
--> 115 x_onp = [onp.asarray(leaf, dtype).reshape(-1)
116 for leaf in tree_util.tree_leaves(x_jnp)]
117 # NOTE(fllinares): return value must *not* be read-only, I believe.
118 return onp.concatenate(x_onp)
File /jaxOptim/lib/python3.8/site-packages/jaxopt/_src/scipy_wrappers.py:115, in <listcomp>(.0)
98 def jnp_to_onp(x_jnp: Any,
99 dtype: Optional[Any] = onp.float64) -> onp.ndarray:
100 """Converts JAX PyTree into repr suitable for scipy.optimize.minimize.
101
102 Several of SciPy's optimization routines require inputs and/or outputs to be
(...)
113 determined by NumPy's casting rules for the concatenate method.
114 """
--> 115 x_onp = [onp.asarray(leaf, dtype).reshape(-1)
116 for leaf in tree_util.tree_leaves(x_jnp)]
117 # NOTE(fllinares): return value must *not* be read-only, I believe.
118 return onp.concatenate(x_onp)
[... skipping hidden 1 frame]
File /jaxOptim/lib/python3.8/site-packages/jax/_src/errors.py:321, in TracerArrayConversionError.__init__(self, tracer)
319 def __init__(self, tracer: "core.Tracer"):
320 super().__init__(
--> 321 "The numpy.ndarray conversion method __array__() was called on "
322 f"the JAX Tracer object {tracer}{tracer._origin_msg()}")
[... skipping hidden 1 frame]
File /jaxOptim/lib/python3.8/site-packages/jax/interpreters/partial_eval.py:1702, in arg_info_pytree(fn, in_tree, has_kwargs, flat_pos)
1699 def arg_info_pytree(fn: Callable, in_tree: PyTreeDef, has_kwargs: bool,
1700 flat_pos: List[int]) -> str:
1701 dummy_args = [False] * in_tree.num_leaves
-> 1702 for i in flat_pos: dummy_args[i] = True
1703 if has_kwargs:
1704 args, kwargs = tree_unflatten(in_tree, dummy_args)
IndexError: list assignment index out of range
Have you an idea if I can use jit in this case and how? Thanks
Thanks for the report! This is a recurring issue. Because ScipyBoundedMinimize calls to NumPy-code (i.e., non-native JAX code), it cannot be jitted. We would need to either wrap the internal code with a host_callback or raise an error with we detect that the function is being jitted.
Hi @mblondel you may be interested in the discussion jax/jit
By the way, what is the difference with the jax.scipy BFGS solver, does it is the bounds?