jax-metal: inconsistent index clipping in jax.lax.dynamic_slice
Description
import jax
import jax.numpy as jnp
def f(x, idx):
    return jax.lax.dynamic_slice(x, [idx], [2])
x = jnp.array([1, 2, 3])
idx = jnp.array(2)
# Print lowered HLO
print(jax.jit(f).lower(x, idx).as_text())
print(jax.jit(f)(x, idx))
In the example above, idx is too high: there are not enough elements left to slice, so the requirement start_index + length <= dimension is violated (2 + 2 > 3). In such a case the index should be clipped to 1, returning [2, 3] (this is the current behaviour on the CPU platform). jax-metal returns [3, 3], but it actually behaves correctly if we replace idx with the constant 2 (the operation doesn't change).
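For reference, the clipping rule described above can be written in plain Python. This is only a sketch of the expected semantics for a 1-D input with a non-negative start index, not jax or jax-metal code; the helper names clamp_start and reference_dynamic_slice_1d are hypothetical and exist only for this illustration.

```python
def clamp_start(start, size, dim):
    """Clamp a slice start so that [start, start + size) fits inside [0, dim).

    Assumes size <= dim. The lower bound of 0 is included for completeness;
    how jax treats negative indices is not modeled here.
    """
    return max(0, min(start, dim - size))


def reference_dynamic_slice_1d(xs, start, size):
    """Pure-Python stand-in for a 1-D dynamic slice with a clamped start."""
    start = clamp_start(start, size, len(xs))
    return xs[start:start + size]


# Same inputs as the report: x = [1, 2, 3], idx = 2, slice size = 2.
# The start is clamped from 2 to 1, so the slice covers the last two elements.
print(reference_dynamic_slice_1d([1, 2, 3], 2, 2))  # [2, 3]
```

With these inputs the sketch gives [2, 3], matching the CPU result quoted above, whereas [3, 3] does not correspond to any in-bounds window of x.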
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
jax-metal 0.0.7