MLIRError raised when using jax.lax.shift_right_logical in Pallas
Description
To reproduce, run the following on a TPU Colab runtime:
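import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def buggy_kernel(x_ref, z_ref):
    # Load the int8 block, shift each element right by 4 bits, store the result.
    z_ref[...] = jax.lax.shift_right_logical(x_ref[...], 4)


def buggy(x: jax.Array):
    return pl.pallas_call(
        buggy_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(2,),
        # Each grid step handles half of the rows.
        # BlockSpec(index_map, block_shape) is the argument order in jax 0.4.26.
        in_specs=[
            pl.BlockSpec(lambda i: (i, 0), (x.shape[0] // 2, x.shape[1])),
        ],
        out_specs=pl.BlockSpec(lambda i: (i, 0), (x.shape[0] // 2, x.shape[1])),
    )(x)


buggy(jnp.zeros((1024, 512), dtype=jnp.int8))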
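For comparison, here is a minimal pure-`jnp` sketch of an int8 logical right shift that does the arithmetic in int32 instead. The helper name `shift_right_logical_int8_via_int32` is hypothetical (not a JAX API), and it is an assumption, untested on TPU, that widening to 32 bits before shifting would avoid the MLIRError if applied to `x_ref[...]` inside the kernel. It can at least serve as a CPU-checkable reference for the expected kernel output.

```python
import jax
import jax.numpy as jnp


def shift_right_logical_int8_via_int32(x: jax.Array, shift: int) -> jax.Array:
    """Logical right shift of an int8 array, computed in int32.

    Hypothetical workaround helper (not part of JAX): widen to int32, mask to
    the low 8 bits so the value is the unsigned reading of the int8 bit
    pattern, shift, then narrow back to int8.
    """
    assert 0 <= shift < 8, "shift must be in [0, 8) for an 8-bit value"
    widened = x.astype(jnp.int32) & 0xFF  # unsigned view of the 8 bits, in [0, 255]
    # The widened value is non-negative, so an arithmetic shift equals a logical one.
    return (widened >> shift).astype(jnp.int8)


if __name__ == "__main__":
    x = jnp.array([-1, -128, 127, 16, 0], dtype=jnp.int8)
    print(shift_right_logical_int8_via_int32(x, 4))
```

The mask is what makes the shift logical: without it, sign extension during the widening cast would shift copies of the sign bit into the result.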
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.3
python: 3.11.7
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]