jax.numpy.digitize doesn't work with shape polymorphism

Open tchatow opened this issue 7 months ago • 6 comments

Description

Tracing jax.numpy.digitize with shape polymorphism (a symbolic dimension) raises a TypeError inside searchsorted:

  File ".../lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py", line 7548, in searchsorted
    dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64
                     ^^^^^^
TypeError: '_DimExpr' object cannot be interpreted as an integer
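
The message appears to come from Python's built-in `len()`, which requires `__len__` to return a real integer. The sketch below uses no JAX at all: `FakeDim` and `FakeArray` are hypothetical stand-ins (not JAX classes) for a symbolic dimension and an array whose leading dimension is symbolic. It only illustrates the mechanism, under the assumption that the traced array's `__len__` hands back the symbolic dimension.

```python
class FakeDim:
    """Hypothetical stand-in for a symbolic dimension; it defines no __index__."""

    def __init__(self, name):
        self.name = name


class FakeArray:
    """Hypothetical array type whose leading dimension is symbolic."""

    def __init__(self, shape):
        self.shape = shape

    def __len__(self):
        # Returns the symbolic dimension object rather than a Python int.
        return self.shape[0]


a = FakeArray((FakeDim("n"),))

try:
    len(a)  # len() tries to coerce the result of __len__ to an integer
    message = None
except TypeError as err:
    message = str(err)

print(message)

# Reading the shape directly involves no integer coercion.
leading = a.shape[0]
```

Reading `a.shape[0]` instead of calling `len(a)` sidesteps the coercion, although any later comparison against a concrete integer would still need to be supported by the dimension type.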

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
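
A possible reproduction, as a minimal sketch: the issue does not say how shape polymorphism was enabled, so this assumes the `jax.export` API with a symbolic dimension `n` on `bins`. The names `bucketize` and `try_export` are made up for illustration. On versions where the problem has since been fixed, the export may simply succeed, so the script reports either outcome.

```python
import jax
import jax.numpy as jnp
from jax import export


def bucketize(x, bins):
    # The call reported in this issue.
    return jnp.digitize(x, bins)


def try_export():
    """Export bucketize with a symbolic length for bins and report the outcome."""
    (n,) = export.symbolic_shape("n")  # assumption: bins has symbolic length n
    x_spec = jax.ShapeDtypeStruct((4,), jnp.float32)
    bins_spec = jax.ShapeDtypeStruct((n,), jnp.float32)
    try:
        export.export(jax.jit(bucketize))(x_spec, bins_spec)
    except Exception as err:  # the issue reports a TypeError here
        return f"failed: {type(err).__name__}: {err}"
    return "exported"


outcome = try_export()
print(outcome)
```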

tchatow, Jul 17 '24 14:07