jax.numpy.digitize doesn't work with shape polymorphism

Open tchatow opened this issue 1 year ago • 6 comments

Description

Tracing jax.numpy.digitize with shape polymorphism raises an error:

  File ".../lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 7548, in searchsorted
    dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
                     ^^^^^^
TypeError: '_DimExpr' object cannot be interpreted as an integer

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0

tchatow avatar Jul 17 '24 14:07 tchatow

Thanks for reporting this! Any chance you can share a repro?

mattjj avatar Jul 17 '24 16:07 mattjj

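Here's a simple example:

import jax
import jax.numpy as jnp
from jax import export

# Symbolic dimension for the polymorphic axis
N, = export.symbolic_shape("N")
f = export.export(jax.jit(jnp.digitize))

shape0 = jax.ShapeDtypeStruct((10,), jnp.int32)  # x: static shape
shape1 = jax.ShapeDtypeStruct((N,), jnp.int32)   # bins: symbolic length N
f(shape0, shape1)  # tracing fails with the TypeError above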

tchatow avatar Jul 17 '24 19:07 tchatow

This is trickier than I thought. If we fix that local issue (e.g. by replacing len(a) with a.shape[0]), we hit a downstream issue: the searchsorted implementations rely on static sizes (or at least size bounds).
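
The local failure can be illustrated without JAX: Python's built-in len() requires __len__ to return something interpretable as an integer, while reading shape[0] has no such restriction. The sketch below uses hypothetical stand-ins (SymbolicDim and SymbolicArray are made up for illustration and are not JAX's _DimExpr or tracer classes). It shows only why len(a) raises where a.shape[0] does not, and says nothing about the downstream static-size issue.

```python
class SymbolicDim:
    """Hypothetical stand-in for a symbolic dimension such as N."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class SymbolicArray:
    """Hypothetical array-like whose leading dimension may be symbolic."""

    def __init__(self, shape):
        self.shape = shape

    def __len__(self):
        # Returns shape[0] unchanged, so a non-integer leaks out to len().
        return self.shape[0]


def try_len(a):
    """Return len(a), or the TypeError message if len() rejects the result."""
    try:
        return len(a)
    except TypeError as e:
        return f"TypeError: {e}"


static = SymbolicArray((10,))
symbolic = SymbolicArray((SymbolicDim("N"),))

len_static = try_len(static)      # plain int, len() succeeds
len_symbolic = try_len(symbolic)  # len() rejects the non-integer result
dim_symbolic = symbolic.shape[0]  # attribute access works, stays symbolic

print(len_static)
print(len_symbolic)
print(dim_symbolic)
```

Anything downstream that then needs a concrete integer (a dtype choice based on size, a loop bound) still has to cope with the symbolic value, which is the harder part described above.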

@jakevdp any ideas?

mattjj avatar Jul 20 '24 02:07 mattjj

Not sure... it's still not clear to me to what extent we should expect shape polymorphism to be supported in JAX APIs. Do we have those goals documented anywhere? It's pretty incomplete at the moment: if we opened issues like this one for every numpy API that doesn't support shape polymorphism, we wouldn't have time to do anything else 😀

jakevdp avatar Jul 20 '24 14:07 jakevdp

@jakevdp I've also observed that shape polymorphism for np.linalg.inv works in 0.4.31 but has regressed on main (regression in 7b415834145df96f0c80e5cf3cb34cfe796f85b6). Would this be a bug or a feature request?

tchatow avatar Sep 09 '24 19:09 tchatow

I can take this on. Shape polymorphism support for the JAX APIs is (always) a work in progress, but it is pretty far along. I now fix issues as they pop up.

gnecula avatar Sep 17 '24 08:09 gnecula