
MLIRError raised when using jax.lax.shift_right_logical in pallas

Status: Open · josipd opened this issue 11 months ago · 0 comments

Description

To reproduce, run the following on a TPU Colab:

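```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def buggy_kernel(x_ref, z_ref):
  # Load the current block, logically shift every element right by 4 bits
  # (shift amount given with the same dtype as x), and store the result.
  x = x_ref[...]
  z_ref[...] = jax.lax.shift_right_logical(x, np.array(4, dtype=x.dtype))


def buggy(x: jax.Array):
  # Each of the 2 grid steps handles half of the rows and all of the columns.
  # BlockSpec is built with keyword arguments because later JAX releases
  # swapped its positional order to (block_shape, index_map).
  block = (x.shape[0] // 2, x.shape[1])
  return pl.pallas_call(
    buggy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(2,),
    in_specs=[pl.BlockSpec(index_map=lambda i: (i, 0), block_shape=block)],
    out_specs=pl.BlockSpec(index_map=lambda i: (i, 0), block_shape=block),
  )(x)


buggy(jnp.zeros((1024, 512), dtype=jnp.int8))
```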
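For comparison, here is a minimal sketch of what the kernel is expected to compute, running the same op through regular XLA (no Pallas/Mosaic), next to a hypothetical int32-based emulation of the 8-bit logical shift that might serve as a workaround. The helper name `shift_right_logical_int8_via_int32` is made up for illustration, and the workaround has not been tested against this MLIRError on TPU. For the all-zeros input in the repro, the expected output is simply all zeros.

```python
import jax
import jax.numpy as jnp
import numpy as np


def reference_shift(x: jax.Array, n: int) -> jax.Array:
  # The kernel's op, lowered by regular XLA instead of Pallas/Mosaic.
  return jax.lax.shift_right_logical(x, np.array(n, dtype=x.dtype))


def shift_right_logical_int8_via_int32(x: jax.Array, n: int) -> jax.Array:
  # Hypothetical helper (not a JAX API): emulate an 8-bit logical right
  # shift with 32-bit ops, assuming only the int8 shift fails to lower.
  assert x.dtype == jnp.int8 and 1 <= n <= 7
  wide = x.astype(jnp.int32) & 0xFF  # reinterpret each byte as 0..255
  shifted = jax.lax.shift_right_logical(wide, np.array(n, dtype=np.int32))
  return shifted.astype(jnp.int8)    # in 0..127 after shifting by n >= 1, so it fits


x = jnp.array([-16, -1, 0x70, 5], dtype=jnp.int8)
# Zero-filled from the left, unlike an arithmetic shift (which would give -1 for -16).
print(reference_shift(x, 4).tolist())                     # [15, 15, 7, 0]
print(shift_right_logical_int8_via_int32(x, 4).tolist())  # [15, 15, 7, 0]
```

Inside the kernel, the helper would replace the direct call, e.g. `z_ref[...] = shift_right_logical_int8_via_int32(x_ref[...], 4)`; whether Mosaic in jaxlib 0.4.26 lowers the int8 to int32 widening is not verified here.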

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.3
python: 3.11.7
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
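The system info above is in the layout printed by JAX's own environment helper, so a minimal way to regenerate it on another machine (assuming a JAX release that ships `jax.print_environment_info`) is:

```python
import jax

# Prints the jax/jaxlib/numpy/python versions and the visible devices,
# in the same layout as the System info section of this issue.
jax.print_environment_info()
```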

josipd, Mar 08 '24 13:03