
jax-metal: inconsistent behaviour of overflowing update slice in `jax.lax.dynamic_update_slice`

Open: jonatanklosko opened this issue 1 month ago (0 comments)

Description

import jax
import jax.numpy as jnp

def f(x, update):
  return jax.lax.dynamic_update_slice(x, update, [1, 1])

x = jnp.array([[1, 2, 3], [4, 5, 6]])
update = jnp.array([[7, 8], [9, 10]])

# Print lowered HLO
print(jax.jit(f).lower(x, update).as_text())
print(jax.jit(f)(x, update))

In the example above, placing `update` at the given start indices overflows `x`: a 2x2 update starting at (1, 1) needs rows 1 and 2, but `x` only has rows 0 and 1. jax-metal ignores the overflow and returns [[1, 2, 3], [4, 7, 8]], updating only the second row of `x` and dropping the second row of `update`. On the CPU platform, jax clamps the start indices in such cases (when start_index + update_dimension > operand_dimension), so the same call returns [[1, 7, 8], [4, 9, 10]].
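The clamping rule can be modelled in a few lines. The sketch below is illustrative only: `reference_dynamic_update_slice` and `clamped_dynamic_update_slice` are hypothetical helpers, not part of jax. The first reproduces in NumPy the CPU behaviour described above (each start index is clamped to `[0, operand_dim - update_dim]`). The second is a possible workaround that clamps the start indices itself before calling `jax.lax.dynamic_update_slice`, assuming jax-metal handles in-range start indices correctly (this has not been verified on Metal here).

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_dynamic_update_slice(operand, update, start_indices):
    """Hypothetical NumPy model of the CPU behaviour: clamp each start index
    to [0, operand_dim - update_dim] so the whole update fits, then write it."""
    out = np.array(operand, copy=True)
    starts = [
        min(max(int(s), 0), out.shape[d] - update.shape[d])
        for d, s in enumerate(start_indices)
    ]
    region = tuple(slice(s, s + n) for s, n in zip(starts, update.shape))
    out[region] = update
    return out


def clamped_dynamic_update_slice(operand, update, start_indices):
    """Hypothetical workaround: clamp the start indices explicitly so the
    backend only ever sees in-range indices. Shapes are static under jit,
    so the clamp bounds are plain Python ints."""
    clamped = [
        jnp.clip(jnp.asarray(s), 0, operand.shape[d] - update.shape[d])
        for d, s in enumerate(start_indices)
    ]
    return jax.lax.dynamic_update_slice(operand, update, clamped)


x = jnp.array([[1, 2, 3], [4, 5, 6]])
update = jnp.array([[7, 8], [9, 10]])

# Both should give [[1, 7, 8], [4, 9, 10]] for start indices [1, 1].
expected = reference_dynamic_update_slice(np.asarray(x), np.asarray(update), [1, 1])
result = jax.jit(clamped_dynamic_update_slice)(x, update, [1, 1])
print(expected.tolist())
print(np.asarray(result).tolist())
```

Because the clamped indices are always in range, the result of the workaround should not depend on how a particular backend treats an overflowing update; the cost is one `clip` per dimension.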

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

jonatanklosko, May 23 '24 11:05