fast-soft-sort
Unable to jit jax ops
Thanks for the work on this! Here's an example:
import jax.numpy as jnp
from fast_soft_sort.jax_ops import soft_rank
from jax import grad, jit
@jit
def f1(x):
    x = x.reshape(1, 3)
    y = soft_rank(x)[0]
    return y.mean()
x = jnp.array([1.0, 2.0, 3.0])
f2 = grad(f1)
f2(x)
Traceback (most recent call last):
  File "test.py", line 15, in <module>
    f2(x)
  File "test.py", line 9, in f1
    y = soft_rank(x)[0]
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 80, in soft_rank
    return jnp.vstack([func(val) for val in values])
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 80, in <listcomp>
    return jnp.vstack([func(val) for val in values])
  File "/home/patrick/.pyenv/versions/3.8.5/lib/python3.8/site-packages/fast_soft_sort/jax_ops.py", line 35, in _func_fwd
    values = np.array(values)
jax._src.traceback_util.FilteredStackTrace: Exception: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
The current JAX implementation is just a wrapper around the Numba-based code (it converts its inputs with `np.array`, as the traceback shows), so unfortunately it can't be used together with jit.
There is ongoing work on a Numba/JAX bridge, which would make it possible to call Numba code from within a jitted JAX function, but it may take some time to land in JAX master.
CC @josipd
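To make the failure mode concrete, here is a minimal sketch (the function names are hypothetical, not part of fast-soft-sort). Under `jit`, inputs are abstract Tracers, so any call that forces a conversion to a concrete `numpy.ndarray` fails, while the same logic written with `jax.numpy` traces fine. Recent JAX versions raise `jax.errors.TracerArrayConversionError` here; the older version in the traceback above raised a plain `Exception` with the same message.

```python
import numpy as np
import jax
import jax.numpy as jnp

def numpy_mean(x):
    # Forces a concrete NumPy array, like the np.array(values) call in the wrapper.
    return np.mean(np.asarray(x))

def jnp_mean(x):
    # Pure jax.numpy: traceable, so it works under jit (and grad).
    return jnp.mean(x)

x = jnp.array([1.0, 2.0, 3.0])
print(jax.jit(jnp_mean)(x))  # 2.0

try:
    jax.jit(numpy_mean)(x)
    error_name = None
except jax.errors.TracerArrayConversionError as err:
    error_name = type(err).__name__
print("numpy_mean under jit raised:", error_name)
```

Called eagerly (without `jit`), `numpy_mean(x)` works, which is why the wrapper only breaks once it is traced.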
Sorry to jump in, but that's actually easier than you might think. See the code below for an example:
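from functools import partial
import numba as nb
from jax import jit, ShapeDtypeStruct
from jax.experimental import host_callback  # deprecated in newer JAX releases (see jax.pure_callback)
@nb.jit
def some_numba_stuff(x):
    return x
@partial(jit, backend="gpu")  # requires a GPU; drop backend="gpu" to run on CPU
def some_jax_stuff(x):
    # Run the Numba function on the host; result_shape tells JAX what it returns.
    y = host_callback.call(some_numba_stuff, x, result_shape=ShapeDtypeStruct(x.shape, x.dtype))
    z = 2 * y
    return z
print(some_jax_stuff(5.))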
Because host_callback is still experimental, I don't expect @mblondel would want it in his code, but depending on what @patrickpei is up to, it would probably do the trick. You would then have to work with your own fork and modify the NumPy ops wrapper in place. Decisions, decisions :)
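In newer JAX releases `jax.experimental.host_callback` is deprecated, and `jax.pure_callback` is the usual replacement for calling host (NumPy/Numba) code from a jitted function. A host callback is opaque to autodiff, though, so to call `grad` as in the original `f2 = grad(f1)`, the backward pass has to be supplied explicitly with `jax.custom_vjp`. The sketch below assumes a recent JAX; `host_square` and `host_square_vjp` are hypothetical stand-ins for Numba kernels (for soft ranking you would need a forward kernel and a matching VJP kernel).

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_square(x):
    # Hypothetical host-side forward pass; receives a plain numpy.ndarray.
    return np.square(x)

def host_square_vjp(x, g):
    # Hypothetical host-side backward pass: d(x^2)/dx times the cotangent g.
    return 2.0 * x * g

@jax.custom_vjp
def square(x):
    # Declare the shape/dtype of the host result so JAX can trace through it.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_square, spec, x)

def square_fwd(x):
    # Forward pass: compute the output and keep x as the residual.
    return square(x), x

def square_bwd(x, g):
    # Backward pass: also runs on the host via a callback.
    spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(host_square_vjp, spec, x, g),)

square.defvjp(square_fwd, square_bwd)

grad_fn = jax.jit(jax.grad(lambda x: square(x).sum()))
print(grad_fn(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```

The callbacks run on the host, so on a GPU every call involves a device-to-host round trip; this keeps jit and grad working but does not make the Numba code itself run on the accelerator.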
Or you could just replace the Numba implementation of isotonic regression with a JAX one. It can't use PAV (the pool-adjacent-violators algorithm), so the solutions are not exact, but it works:
import jax
import jax.numpy as jnp
import jaxopt as jo
def projection_non_negative_after0(x: jnp.ndarray, hyperparams=None) -> jnp.ndarray:
    # Clip every entry except the first to be >= 0; the first entry stays free.
    return x.at[1:].set(jnp.where(x[1:] < 0.0, 0.0, x[1:]))
def constrain_param(param):
    # param = [v_1, d_2, ..., d_n] with d_i >= 0 maps to v_i = v_1 - (d_2 + ... + d_i),
    # which is non-increasing by construction.
    return param.at[1:].set(param[0] - jnp.cumsum(param[1:]))
def isotonic_opt_l2(y: jnp.ndarray) -> jnp.ndarray:
    """Solves an isotonic regression problem with L2 loss using projected gradient descent.
    Formally, it solves argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2.
    Args:
      y: input to isotonic regression, a 1d-array.
    """
    def loss_fn(param):
        return jnp.sum((y - constrain_param(param)) ** 2) / 2
    solver = jo.ProjectedGradient(
        fun=loss_fn, maxiter=10000, projection=projection_non_negative_after0
    )
    # Start from v = (y_1, ..., y_1), i.e. all decrements zero.
    sol = solver.run(y.at[1:].set(0.0)).params
    # Map back using the optimized v_1 (sol[0]), not the input y[0].
    return constrain_param(sol)
def isotonic_opt_kl(y: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
    """Solves an isotonic regression problem with KL divergence using projected gradient descent.
    Formally, it solves argmin_{v_1 >= ... >= v_n} <e^{y-v}, 1> + <e^w, v>.
    Args:
      y: input to isotonic optimization, a 1d-array.
      w: input to isotonic optimization, a 1d-array.
    """
    def loss_fn(param):
        constr = constrain_param(param)
        return jnp.sum(jnp.exp(y - constr)) + jnp.sum(jnp.exp(w) * constr)
    solver = jo.ProjectedGradient(
        fun=loss_fn, maxiter=10000, projection=projection_non_negative_after0
    )
    sol = solver.run(y.at[1:].set(0.0)).params
    return constrain_param(sol)
y = jnp.array([10.0, 7.0, 6.5, 6.8, 7.0, 8.0, 9.0])
print(jax.jit(isotonic_opt_l2)(y))
print(jax.jit(isotonic_opt_kl)(y, jnp.ones_like(y)))
I can JIT this without a problem. I'm not sure about the correctness of the KL solution, though.
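For checking the L2 output on small inputs, an exact and fully jittable alternative is the classical min-max characterization of isotonic regression (a different technique from both PAV and the projected-gradient approach above). For the non-increasing case in the docstring, v_i = min_{j <= i} max_{k >= i} mean(y_j, ..., y_k). The sketch below (`isotonic_l2_minmax` is a hypothetical name) computes all segment means from a cumulative sum, so it uses O(n²) memory and is only meant for short vectors.

```python
import jax
import jax.numpy as jnp

def isotonic_l2_minmax(y):
    """Exact argmin_{v_1 >= ... >= v_n} 0.5 ||v - y||^2 via the min-max formula."""
    n = y.shape[0]
    csum = jnp.concatenate([jnp.zeros((1,), y.dtype), jnp.cumsum(y)])
    j = jnp.arange(n)[:, None]  # segment start (row index)
    k = jnp.arange(n)[None, :]  # segment end / position i (column index)
    # means[j, k] = mean(y[j..k]); entries with k < j are invalid and set to -inf.
    means = (csum[k + 1] - csum[j]) / jnp.maximum(k - j + 1, 1)
    means = jnp.where(k >= j, means, -jnp.inf)
    # suffix_max[j, i] = max_{k >= i} means[j, k]
    suffix_max = jax.lax.cummax(means, axis=1, reverse=True)
    # v_i = min_{j <= i} suffix_max[j, i]; rows with j > i are masked out.
    return jnp.min(jnp.where(j <= k, suffix_max, jnp.inf), axis=0)

print(jax.jit(isotonic_l2_minmax)(jnp.array([1.0, 2.0])))  # [1.5 1.5]
```

On the example above this gives v_1 = 10 and about 7.3833 for every remaining entry. In that example the optimal v_1 happens to equal y[0]; an input like `[1.0, 2.0]` shows that in general it does not, which is why the projected-gradient result has to be mapped back with the optimized first entry.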
Have you solved this? It just does not work within JAX!