jaxopt
Errors with `ScipyBoundedMinimize`
I tried the following:
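```python
import jax
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize

# state_action_value_jax and model (with a .grid attribute) are defined elsewhere.
solver = ScipyBoundedMinimize(fun=state_action_value_jax, method="l-bfgs-b")

def T_jax(v, model):
    def update_v(carry, y):
        # For each grid point y, search over x in [1e-5, y].
        b = jnp.array((1e-5, y))
        sol = solver.run(y, bounds=b, data=(y, v, model))
        # run() returns an OptStep: .params is the solution, .state holds fun_val and success.
        return carry + 1, (sol.params, -sol.state.fun_val, sol.state.success)

    _, v_values = jax.lax.scan(update_v, 0, model.grid)
    return v_values
```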
Calling `T_jax` raises the following error:
```
/usr/local/lib/python3.8/dist-packages/jaxopt/_src/scipy_wrappers.py in jnp_to_onp(x_jnp, dtype)
    116     determined by NumPy's casting rules for the concatenate method.
    117   """
--> 118   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    119            for leaf in tree_util.tree_leaves(x_jnp)]
    120   # NOTE(fllinares): return value must *not* be read-only, I believe.

/usr/local/lib/python3.8/dist-packages/jaxopt/_src/scipy_wrappers.py in <listcomp>(.0)
    116     determined by NumPy's casting rules for the concatenate method.
    117   """
--> 118   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    119            for leaf in tree_util.tree_leaves(x_jnp)]
    120   # NOTE(fllinares): return value must *not* be read-only, I believe.

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/1)>
```
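The same failure can be reproduced without jaxopt, assuming only `jax` and NumPy are installed. Inside the body of `jax.lax.scan`, the scanned value is an abstract tracer with no concrete data, so converting it with `onp.asarray` (as `jnp_to_onp` does in the traceback above) raises this error. The function names here are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as onp

def body(carry, y):
    # Inside lax.scan, y is a tracer, not a concrete array.
    # Converting it to NumPy mirrors what jaxopt's jnp_to_onp does.
    x = onp.asarray(y, dtype=onp.float64)
    return carry, x

def reproduce():
    try:
        jax.lax.scan(body, 0, jnp.arange(3.0))
    except jax.errors.TracerArrayConversionError:
        return "TracerArrayConversionError"
    return "no error"

print(reproduce())
```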
Separately, this line initially raised an error, which I fixed locally: https://github.com/google/jaxopt/blob/a51d5ed87914a6e5c9aac698d206cf625b7e12ce/jaxopt/_src/scipy_wrappers.py#L303

The error was:

```
AttributeError: module 'scipy' has no attribute 'optimize'
```
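The local fix is not shown. A common cause of this error, assuming the failing line reaches `scipy.optimize` through a bare `import scipy`, is that SciPy versions which do not lazy-load submodules leave `scipy.optimize` unbound until it is imported explicitly. A minimal sketch of the explicit import, using a toy bounded problem:

```python
# Importing the submodule explicitly binds scipy.optimize,
# whether or not the installed SciPy lazy-loads its submodules.
import scipy.optimize

# Toy bounded problem: minimise (x - 2)^2 over x in [1e-5, 10].
res = scipy.optimize.minimize(
    lambda x: (x[0] - 2.0) ** 2,
    x0=[1.0],
    method="L-BFGS-B",
    bounds=[(1e-5, 10.0)],
)
print(res.x, res.success)
```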
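Looks like you're calling `ScipyBoundedMinimize` from inside traced code: the body of `jax.lax.scan` is traced just like a jitted function. This is not currently supported, because the SciPy wrappers convert their inputs to NumPy arrays, which fails on JAX tracers. Once #372 is done, it should work.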
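Until then, one possible workaround is to keep the SciPy-backed solve out of traced code and iterate over the grid in an ordinary Python loop, so every value SciPy sees is concrete. The sketch below calls `scipy.optimize.minimize` with `method="L-BFGS-B"` directly rather than through jaxopt, and `toy_objective` is a hypothetical stand-in for `state_action_value_jax`; the same loop structure should carry over to calling `solver.run` outside of `jax.lax.scan`.

```python
import numpy as np
from scipy.optimize import minimize

def toy_objective(x, y):
    # Hypothetical stand-in for state_action_value_jax; minimised at x = y / 2.
    return (x[0] - 0.5 * y) ** 2

def T_loop(grid, fun=toy_objective):
    """Solve one bounded problem per grid point in a plain Python loop.

    Each y is a concrete float, so SciPy never receives a JAX tracer.
    """
    xs, vals, ok = [], [], []
    for y in np.asarray(grid, dtype=float):
        res = minimize(fun, x0=[y / 4], args=(y,), method="L-BFGS-B",
                       bounds=[(1e-5, y)])
        xs.append(res.x[0])
        vals.append(-res.fun)  # negated, as in T_jax, to report a maximum
        ok.append(res.success)
    return np.array(xs), np.array(vals), np.array(ok)

grid = np.linspace(0.1, 2.0, 5)
x_opt, v_new, success = T_loop(grid)
print(x_opt, success)
```

The cost is speed: the Python loop gives up the compiled `scan`, which is usually acceptable for moderately sized grids.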