
`jnp.digitize` does not work in corner case with no bins

Open Gattocrucco opened this issue 1 year ago • 2 comments

Description

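`np.digitize(x, [])` is equivalent to `zeros_like(x, int)`, but `jnp.digitize` raises an error instead. The problem looks like a simple typo in the code: the empty-bins branch calls `zeros(x, ...)`, which treats `x` as a shape, where it should call `zeros_like(x, ...)`.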
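For reference, the NumPy behavior described above can be checked with a short snippet (a sketch assuming a recent NumPy; the issue reports 1.26.4). With no bin edges, every value falls into bin 0:

```python
import numpy as np

x = np.arange(4)

# No bin edges: every value of x is assigned to bin 0.
result = np.digitize(x, [])
print(result)  # [0 0 0 0]

# Same values as an all-zero integer array with the shape of x.
print(np.array_equal(result, np.zeros_like(x, int)))  # True
```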

from jax import numpy as jnp
jnp.digitize(jnp.arange(4), jnp.empty(0))
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[123], line 2
      1 from jax import numpy as jnp
----> 2 jnp.digitize(jnp.arange(4), jnp.empty(0))

    [... skipping hidden 12 frame]

File ~/Documents/Programmi/micromamba/envs/bartz/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:5383, in digitize(x, bins, right)
   5381   raise ValueError(f"digitize: bins must be a 1-dimensional array; got {bins=}")
   5382 if bins_arr.shape[0] == 0:
-> 5383   return zeros(x, dtype=dtypes.canonicalize_dtype(int_))
   5384 side = 'right' if not right else 'left'
   5385 return where(
   5386   bins_arr[-1] >= bins_arr[0],
   5387   searchsorted(bins_arr, x, side=side),
   5388   len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side)
   5389 )

File ~/Documents/Programmi/micromamba/envs/bartz/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:2332, in zeros(shape, dtype, device)
   2330 if (m := _check_forgot_shape_tuple("zeros", shape, dtype)): raise TypeError(m)
   2331 dtypes.check_user_dtype_supported(dtype, "zeros")
-> 2332 shape = canonicalize_shape(shape)
   2333 return lax.full(shape, 0, _jnp_dtype(dtype), sharding=_normalize_to_sharding(device))

File ~/Documents/Programmi/micromamba/envs/bartz/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:86, in canonicalize_shape(shape, context)
     84   return core.canonicalize_shape((shape,), context)  # type: ignore
     85 else:
---> 86   return core.canonicalize_shape(shape, context)

File ~/Documents/Programmi/micromamba/envs/bartz/lib/python3.12/site-packages/jax/_src/core.py:2117, in canonicalize_shape(shape, context)
   2115 except TypeError:
   2116   pass
-> 2117 raise _invalid_shape_error(shape, context)

TypeError: Shapes must be 1D sequences of concrete values of integer type, got Traced<ShapedArray(int32[4])>with<DynamicJaxprTrace(level=1/0)>.
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function digitize at /Users/giacomo/Documents/Programmi/micromamba/envs/bartz/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:5374 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
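Until a fix lands, the empty-bins case can be handled on the user side. The sketch below is a hypothetical wrapper (`digitize_with_empty_bins` is an illustrative name, not part of JAX) that assumes only `jnp.digitize` and `jnp.zeros_like`. It relies on the fact that `bins.shape` is static under `jit`, so the Python `if` is resolved at trace time, and it uses `zeros_like`, which needs only the shape of `x`, not its traced values:

```python
import numpy as np
import jax
import jax.numpy as jnp


def digitize_with_empty_bins(x, bins, right=False):
    """Hypothetical user-side wrapper: jnp.digitize plus the empty-bins case."""
    x = jnp.asarray(x)
    bins = jnp.asarray(bins)
    # bins.shape is static even under jit, so this branch does not
    # depend on any traced values.
    if bins.ndim == 1 and bins.shape[0] == 0:
        # zeros_like uses only the shape of x; zeros(x) would try to read
        # the (traced) values of x as a shape, which is the error above.
        return jnp.zeros_like(x, dtype=int)
    return jnp.digitize(x, bins, right=right)


x = jnp.arange(4)
print(digitize_with_empty_bins(x, jnp.empty(0)))           # [0 0 0 0]
print(jax.jit(digitize_with_empty_bins)(x, jnp.empty(0)))  # [0 0 0 0]

# Non-empty bins fall through to jnp.digitize unchanged.
print(digitize_with_empty_bins(jnp.array([0.5, 1.5, 2.5]), jnp.array([1.0, 2.0])))  # [0 1 2]
```

The same one-word change (`zeros` to `zeros_like`) is presumably what the fix inside `jnp.digitize` itself would look like.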

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.25
jaxlib: 0.4.25
numpy:  1.26.4
python: 3.12.2 | packaged by conda-forge | (main, Feb 16 2024, 20:54:21) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:43 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6000', machine='arm64')

Gattocrucco • Mar 14 '24 21:03

Good catch! Are you interested in putting together a PR with the fix? If not, one of the team members can take care of it. Thanks!

jakevdp • Mar 14 '24 21:03

OK, I'll do the PR.

Gattocrucco • Mar 14 '24 21:03