jax
jax.numpy.digitize doesn't work with shape polymorphism
Description
Tracing jax.numpy.digitize with shape polymorphism fails with the following error:
File ".../lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 7548, in searchsorted
dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
^^^^^^
TypeError: '_DimExpr' object cannot be interpreted as an integer
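The message is consistent with how Python's built-in len() works: it coerces whatever __len__ returns to a Python int via __index__, and a symbolic dimension has no concrete value at trace time. The following plain-Python sketch reproduces the same kind of TypeError. SymbolicDim and FakeArray are hypothetical stand-ins for _DimExpr and a traced array, assuming (as the error suggests) that the symbolic dimension does not define __index__.

```python
class SymbolicDim:
    """Hypothetical stand-in for a symbolic dimension such as JAX's _DimExpr.

    Its value is only known when the exported function is called with concrete
    shapes, so it does not define __index__ and cannot become a Python int.
    """

    def __init__(self, name: str):
        self.name = name

    def __repr__(self) -> str:
        return self.name


class FakeArray:
    """Hypothetical array whose leading dimension may be symbolic."""

    def __init__(self, shape: tuple):
        self.shape = shape

    def __len__(self):
        # Like an array's __len__, return the size of the leading axis.
        return self.shape[0]


static = FakeArray((10,))
symbolic = FakeArray((SymbolicDim("N"),))

print(len(static))        # 10: a concrete int passes len()'s coercion
print(symbolic.shape[0])  # N: the shape itself is still available

try:
    len(symbolic)  # len() coerces __len__'s result via __index__
except TypeError as err:
    error_message = str(err)
    print(error_message)  # 'SymbolicDim' object cannot be interpreted as an integer
```

Reading a.shape[0] directly avoids this particular coercion, which is why it is the natural local fix.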
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.0
Thanks for reporting this! Any chance you can share a repro?
Here's a simple example:
import jax
import jax.numpy as jnp

N, = jax.export.symbolic_shape("N")
f = jax.export.export(jax.jit(jnp.digitize))
shape0 = jax.ShapeDtypeStruct((10,), jnp.int32)  # x: static shape
shape1 = jax.ShapeDtypeStruct((N,), jnp.int32)   # bins: symbolic length N
f(shape0, shape1)
This is trickier than I thought. If we fix that local issue (e.g., by replacing len(a) with a.shape[0]), we hit a downstream issue: the searchsorted implementations rely on static sizes (or at least static size bounds).
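To illustrate why a static size matters, here is a simplified, hypothetical sketch of a side='left' searchsorted built on a fixed-length binary search. It is not JAX's actual implementation. The point is that the loop trip count, bit_length(n), is derived from the concrete length of the sorted array, which a symbolic dimension cannot provide at trace time (an upper bound on the size would be enough to fix the trip count).

```python
import jax
import jax.numpy as jnp
from jax import lax


def searchsorted_left(a, v):
    """Hypothetical sketch: side='left' searchsorted via a fixed-length binary search.

    Not JAX's actual implementation; it only illustrates why the trip count
    depends on a concrete length of the sorted 1-D array `a`.
    """
    n = a.shape[0]
    if n == 0:
        return jnp.zeros(jnp.shape(v), dtype=jnp.int32)
    # Each step at least halves the search interval, so bit_length(n) steps
    # suffice. int() requires a concrete size; for a symbolic dimension this
    # is expected to fail at trace time.
    num_iters = int(n).bit_length()

    lo = jnp.zeros(jnp.shape(v), dtype=jnp.int32)
    hi = jnp.full(jnp.shape(v), n, dtype=jnp.int32)

    def step(_, state):
        lo, hi = state
        mid = (lo + hi) // 2
        active = lo < hi
        # Once lo == hi, mid can equal n; clamp so the lookup stays in bounds.
        go_right = a[jnp.minimum(mid, n - 1)] < v
        new_lo = jnp.where(active & go_right, mid + 1, lo)
        new_hi = jnp.where(active & ~go_right, mid, hi)
        return new_lo, new_hi

    lo, _ = lax.fori_loop(0, num_iters, step, (lo, hi))
    return lo


a = jnp.array([1, 2, 2, 3, 5])
v = jnp.array([0, 1, 2, 4, 5, 6])
print(jax.jit(searchsorted_left)(a, v))  # [0 0 1 4 4 5]
```

Under jax.jit with concrete shapes, a.shape[0] is a plain int, so the sketch traces fine; only a symbolic leading dimension breaks the trip-count computation.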
@jakevdp any ideas?
Not sure... it's still not clear to me to what extent we should expect shape polymorphism to be supported in JAX APIs. Do we have those goals documented anywhere? Support is pretty incomplete at the moment: if we opened issues like this one for every NumPy API that doesn't support shape polymorphism, we wouldn't have time to do anything else 😀
@jakevdp I also observe that shape polymorphism for np.linalg.inv works in 0.4.31 but has regressed on main. Is this a bug or a feature request? (Regressed in commit 7b415834145df96f0c80e5cf3cb34cfe796f85b6.)
I can take this on. Shape polymorphism support for the JAX APIs is (always) a work in progress, but it is pretty far along. At this point, I fix issues as they pop up.