jaxopt

Errors with `ScipyBoundedMinimize`

Open · Smit-create opened this issue 2 years ago · 2 comments

I tried the following:

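```python
import jax
import jax.numpy as jnp
from jaxopt import ScipyBoundedMinimize

# state_action_value_jax is my objective (not shown); it is called as fun(params, data=...).
solver = ScipyBoundedMinimize(fun=state_action_value_jax, method="l-bfgs-b")

def T_jax(v, model):
    def update_v(carry, y):
        bounds = (1e-5, y)  # (lower, upper), same structure as init_params
        # run() returns OptStep(params, state); state holds fun_val and success
        params, info = solver.run(y, bounds=bounds, data=(y, v, model))
        return carry + 1, (params, -info.fun_val, info.success)
    _, v_values = jax.lax.scan(update_v, 0, model.grid)
    return v_values
```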

Calling `T_jax` raises the following error:

```
/usr/local/lib/python3.8/dist-packages/jaxopt/_src/scipy_wrappers.py in jnp_to_onp(x_jnp, dtype)
    116     determined by NumPy's casting rules for the concatenate method.
    117   """
--> 118   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    119            for leaf in tree_util.tree_leaves(x_jnp)]
    120   # NOTE(fllinares): return value must *not* be read-only, I believe.

/usr/local/lib/python3.8/dist-packages/jaxopt/_src/scipy_wrappers.py in <listcomp>(.0)
    116     determined by NumPy's casting rules for the concatenate method.
    117   """
--> 118   x_onp = [onp.asarray(leaf, dtype).reshape(-1)
    119            for leaf in tree_util.tree_leaves(x_jnp)]
    120   # NOTE(fllinares): return value must *not* be read-only, I believe.

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/1)>
```
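The traceback shows jaxopt converting each leaf of the initial parameters to NumPy with `onp.asarray`. Inside `jax.lax.scan`, the body function receives abstract tracers rather than concrete arrays, and NumPy cannot materialize a tracer. The minimal sketch below is independent of jaxopt (the `body` function is hypothetical) and reproduces the same error class with a bare `np.asarray` call:

```python
import jax
import jax.numpy as jnp
import numpy as np

def body(carry, y):
    # During tracing, y is a Tracer, not a concrete array, so NumPy cannot
    # convert it; this mirrors jaxopt's onp.asarray(leaf, dtype) call.
    np.asarray(y)
    return carry + 1, y

raised = False
try:
    jax.lax.scan(body, 0, jnp.arange(3.0))
except jax.errors.TracerArrayConversionError:
    raised = True

print(raised)
```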

Smit-create · Feb 20 '23 03:02

Also, this line initially raised an error, which I fixed locally: https://github.com/google/jaxopt/blob/a51d5ed87914a6e5c9aac698d206cf625b7e12ce/jaxopt/_src/scipy_wrappers.py#L303

The error was:

```
AttributeError: module 'scipy' has no attribute 'optimize'
```
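In older SciPy releases, `import scipy` on its own does not import subpackages such as `scipy.optimize`, so the attribute is missing unless something else has imported it. The comment does not say what the local fix was; assuming it was the usual one, an explicit submodule import makes the lookup succeed. The quadratic objective here is only a hypothetical smoke test:

```python
import scipy
import scipy.optimize  # explicitly load the subpackage so scipy.optimize exists

# With the explicit import, attribute access on the top-level module works.
res = scipy.optimize.minimize(
    lambda x: (x[0] - 2.0) ** 2,  # hypothetical objective, minimum at x = 2
    x0=[0.0],
    method="L-BFGS-B",
)
print(res.x)
```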

Smit-create · Feb 20 '23 03:02

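It looks like you're trying to call `ScipyMinimize` from a jitted function, which is not currently supported. (Here the call sits inside `jax.lax.scan`, which traces its body just as `jax.jit` does; that is why the error mentions a Tracer.) Once #372 is done, it should work.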

mblondel · Mar 01 '23 14:03
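Until that lands, one stopgap that bypasses jaxopt's SciPy wrappers entirely is to move the host-side optimization into `jax.pure_callback`, whose Python callback receives concrete NumPy values even when invoked inside `jax.lax.scan`. This is a different technique from the one #372 will provide, and the sketch makes assumptions: the objective, bounds and grid are hypothetical stand-ins for `state_action_value_jax` and `model.grid`, and SciPy's bounded `minimize_scalar` replaces L-BFGS-B because the decision variable is a scalar:

```python
import jax
import jax.numpy as jnp
import numpy as np
import scipy.optimize

def _host_argmin(y):
    # Runs on the host with a concrete NumPy value, so SciPy can use it.
    # Hypothetical objective: (c - 0.5 * y)**2 over c in [1e-5, y].
    y = float(y)
    res = scipy.optimize.minimize_scalar(
        lambda c: (c - 0.5 * y) ** 2,
        bounds=(1e-5, y),
        method="bounded",
    )
    return np.asarray(res.x, dtype=np.float32)

def update(carry, y):
    # pure_callback needs the callback's output shape and dtype declared up front.
    x = jax.pure_callback(_host_argmin, jax.ShapeDtypeStruct((), jnp.float32), y)
    return carry + 1, x

_, xs = jax.lax.scan(update, 0, jnp.array([1.0, 2.0, 4.0]))
print(xs)
```

The callback is opaque to JAX, so it cannot be differentiated through. If gradients are not needed, a plain Python loop over a concrete `model.grid` (no `scan`, no `jit`) avoids tracing altogether.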