jax

[Pallas TPU] Mosaic legalisation error caused by `select(b, 1, 0)` being eliminated

Status: open — opened by ayaka14732 1 year ago, 0 comments.

Description

import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp

@functools.partial(
    pl.pallas_call,
    in_specs=(  # both boolean inputs are placed in SMEM (TPU scalar memory)
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
    ),
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
    out_shape=jax.ShapeDtypeStruct((1,), jnp.bool_),  # one bool result; the generated Mosaic module shows bool SMEM refs as memref<1xi32>
    # interpret=True,  # NOTE(review): presumably running in interpret mode skips Mosaic compilation and hides the bug — unverified
    debug=True,  # print the kernel jaxpr and the Mosaic module before compiling (see the output below)
)
def kernel(x_ref, y_ref, o_ref):  # Repro: o_ref[0] = where(x_ref[0], 1, 0) == where(y_ref[0], 1, 0)
    a = jnp.int32(1)  # the value 1 is what triggers the failure; per the report, a = 2 compiles and runs
    b = jnp.int32(0)
    x = jnp.where(x_ref[0], a, b)  # bool -> int32, lowered to arith.select(x_ref[0], 1, 0)
    y = jnp.where(y_ref[0], a, b)  # same select(pred, 1, 0) pattern for the second input
    o_ref[0] = x == y  # written as an int32 compare, yet Mosaic reports an (i1, i1) arith.cmpi it cannot legalize

def main():  # Run the repro kernel once on (True, False); the expected result is [False].
    x = jnp.array([True], dtype=jnp.bool_)  # jnp.bool_ matches the dtype spelling used in the kernel's out_shape
    y = jnp.array([False], dtype=jnp.bool_)
    out = kernel(x, y)  # on TPU this call raises the Mosaic "failed to legalize operation 'arith.cmpi'" error
    print(out)

if __name__ == '__main__':  # script entry point; comments are kept on existing lines so traceback line numbers stay valid
    main()

Error:

The kernel jaxpr for pallas_call kernel at /home/ayx/jax/2.py:7:
let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let
    d:i32[] = select_n a c b
  in (d,) } in
{ lambda ; e:MemRef<smem>{bool[1]} f:MemRef<smem>{bool[1]} g:MemRef<smem>{bool[1]}. let
    h:bool[] <- e[0]
    i:i32[] = pjit[name=_where jaxpr=_where] h 1 0
    j:bool[] <- f[0]
    k:i32[] = pjit[name=_where jaxpr=_where] j 1 0
    l:bool[] = eq i k
    g[0] <- l
  in () }

The Mosaic module for pallas_call kernel at /home/ayx/jax/2.py:7:
module @kernel {
  func.func @main(%arg0: memref<1xi32, #tpu.memory_space<smem>>, %arg1: memref<1xi32, #tpu.memory_space<smem>>, %arg2: memref<1xi32, #tpu.memory_space<smem>>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} {
    %c0 = arith.constant 0 : index
    %0 = memref.load %arg0[%c0] : memref<1xi32, #tpu.memory_space<smem>>
    %c0_i32 = arith.constant 0 : i32
    %1 = arith.cmpi ne, %0, %c0_i32 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32_0 = arith.constant 0 : i32
    %2 = arith.select %1, %c1_i32, %c0_i32_0 : i32
    %c0_1 = arith.constant 0 : index
    %3 = memref.load %arg1[%c0_1] : memref<1xi32, #tpu.memory_space<smem>>
    %c0_i32_2 = arith.constant 0 : i32
    %4 = arith.cmpi ne, %3, %c0_i32_2 : i32
    %c1_i32_3 = arith.constant 1 : i32
    %c0_i32_4 = arith.constant 0 : i32
    %5 = arith.select %4, %c1_i32_3, %c0_i32_4 : i32
    %6 = arith.cmpi eq, %2, %5 : i32
    %c0_5 = arith.constant 0 : index
    %7 = memref.load %arg2[%c0_5] : memref<1xi32, #tpu.memory_space<smem>>
    %c0_i32_6 = arith.constant 0 : i32
    %8 = arith.cmpi ne, %7, %c0_i32_6 : i32
    %9 = arith.extui %6 : i1 to i32
    memref.store %9, %arg2[%c0_5] : memref<1xi32, #tpu.memory_space<smem>>
    return
  }
}

Traceback (most recent call last):
  File "/home/ayx/jax/2.py", line 32, in <module>
    main()
    ~~~~^^
  File "/home/ayx/jax/2.py", line 28, in main
    out = kernel(x, y)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'arith.cmpi'

at location: loc(fused["/eq"(callsite("kernel"("/home/ayx/jax/2.py":23:15) at callsite("main"("/home/ayx/jax/2.py":28:10) at "<module>"("/home/ayx/jax/2.py":32:4)))), "/select_n"(callsite("kernel"("/home/ayx/jax/2.py":21:8) at callsite("main"("/home/ayx/jax/2.py":28:10) at "<module>"("/home/ayx/jax/2.py":32:4))))])

The MLIR operation involved:
  %21 = "arith.cmpi"(%16, %20) <{predicate = 0 : i64}> : (i1, i1) -> i1

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Note that if the value of `a` is changed from

a = jnp.int32(1)

to any other number, for example:

a = jnp.int32(2)

The kernel will run successfully.

The MLIR operation reported in the error

The MLIR operation involved:
  %21 = "arith.cmpi"(%16, %20) <{predicate = 0 : i64}> : (i1, i1) -> i1

suggests that `select(b, 1, 0)` has been eliminated: the operands of `arith.cmpi` are `i1`, whereas if the select were still present, the type signature would be

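`(i32, i32) -> i1`.

A likely mechanism: MLIR's `arith` canonicalization rewrites `select(b, 1, 0)` into `arith.extui b`, and then folds `cmpi eq (extui x), (extui y)` into `cmpi eq x, y` directly on `i1` values, which Mosaic cannot legalize. With `a = 2` the select no longer has the `select(b, 1, 0)` form, so this rewrite does not fire and the comparison stays on `i32`.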

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.35.dev20241017+0519db15a
jaxlib: 0.4.34
numpy:  2.1.2
python: 3.13.0 (main, Oct  8 2024, 01:04:00) [Clang 18.1.8 ]
device info: TPU v5 lite-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')

— ayaka14732, Oct 17 '24 15:10