jax.numpy.digitize doesn't work with shape polymorphism
Description
Tracing jax.numpy.digitize with shape polymorphism raises a TypeError. The failure occurs inside searchsorted, which digitize calls:
  File ".../lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 7548, in searchsorted
    dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
                     ^^^^^^
TypeError: '_DimExpr' object cannot be interpreted as an integer
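
The traceback points at len(a): Python's built-in len() requires __len__ to return a real integer, and under shape polymorphism the leading dimension appears to be a symbolic _DimExpr instead. The following is a minimal pure-Python sketch of that mechanism only. SymbolicDim, FakeArray, and leading_dim_error are hypothetical stand-ins invented for illustration, not JAX's actual classes or API:

```python
class SymbolicDim:
    """Hypothetical stand-in for a symbolic dimension (plays the role of _DimExpr)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


class FakeArray:
    """Hypothetical stand-in for a traced array; it only carries a shape."""

    def __init__(self, shape):
        self.shape = shape

    def __len__(self):
        # Returns whatever the leading dimension is, integer or symbolic.
        return self.shape[0]


def leading_dim_error(arr):
    """Return the TypeError message raised by len(arr), or None if len() succeeds."""
    try:
        len(arr)
    except TypeError as exc:
        return str(exc)
    return None


concrete = FakeArray((5,))
symbolic = FakeArray((SymbolicDim("b"),))

print(leading_dim_error(concrete))  # len() works for an integer dimension
print(leading_dim_error(symbolic))  # len() rejects the non-integer dimension
```

In this sketch the symbolic case fails with the same kind of message as the traceback above, which suggests the dtype selection in searchsorted would need to avoid calling len() on an array whose leading dimension is symbolic.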
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.0