jax
[Pallas TPU] Mosaic legalisation error caused by `select(b, 1, 0)` being eliminated
Description
import functools
import jax
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
@functools.partial(
    pl.pallas_call,
    # Both inputs and the output are placed in TPU scalar memory (SMEM).
    in_specs=(
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
    ),
    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
    # A single boolean result.
    out_shape=jax.ShapeDtypeStruct((1,), jnp.bool_),
    # interpret=True,  # toggle to run in interpret mode instead of compiling with Mosaic
    # debug=True makes pallas_call print the kernel jaxpr and the Mosaic module,
    # which is the dump reproduced in the issue text below.
    debug=True,
)
def kernel(x_ref, y_ref, o_ref):
    """Store ``where(x, 1, 0) == where(y, 1, 0)`` into ``o_ref[0]``.

    Minimal reproducer for a Mosaic compile failure. Each boolean input is
    mapped to an int32 via ``jnp.where(flag, 1, 0)``, which lowers to
    ``arith.select %flag, 1, 0``. The two results are then compared for
    equality. Compilation fails with "failed to legalize operation
    'arith.cmpi'" on an ``(i1, i1) -> i1`` comparison. With ``a = 2``
    instead of ``1`` the kernel compiles and runs.

    Args:
        x_ref: single-element boolean SMEM ref (``main`` passes ``[True]``).
            The Mosaic module stores it as ``memref<1xi32>``.
        y_ref: single-element boolean SMEM ref (``main`` passes ``[False]``).
        o_ref: single-element boolean SMEM output ref.
    """
    # The constants 1 and 0 are essential to the repro: select(b, 1, 0) is
    # the pattern that gets eliminated before legalization.
    a = jnp.int32(1)
    b = jnp.int32(0)
    x = jnp.where(x_ref[0], a, b)
    y = jnp.where(y_ref[0], a, b)
    # The int32 equality comparison that Mosaic fails to legalize.
    o_ref[0] = x == y
def main():
    """Run ``kernel`` on a ``[True]`` flag and a ``[False]`` flag and print the output."""
    # Build the two single-element boolean inputs in order: True first, then False.
    lhs, rhs = (jnp.array([flag], dtype=jnp.bool) for flag in (True, False))
    print(kernel(lhs, rhs))


if __name__ == '__main__':
    main()
Error:
The kernel jaxpr for pallas_call kernel at /home/ayx/jax/2.py:7:
let _where = { lambda ; a:bool[] b:i32[] c:i32[]. let
d:i32[] = select_n a c b
in (d,) } in
{ lambda ; e:MemRef<smem>{bool[1]} f:MemRef<smem>{bool[1]} g:MemRef<smem>{bool[1]}. let
h:bool[] <- e[0]
i:i32[] = pjit[name=_where jaxpr=_where] h 1 0
j:bool[] <- f[0]
k:i32[] = pjit[name=_where jaxpr=_where] j 1 0
l:bool[] = eq i k
g[0] <- l
in () }
The Mosaic module for pallas_call kernel at /home/ayx/jax/2.py:7:
module @kernel {
func.func @main(%arg0: memref<1xi32, #tpu.memory_space<smem>>, %arg1: memref<1xi32, #tpu.memory_space<smem>>, %arg2: memref<1xi32, #tpu.memory_space<smem>>) attributes {dimension_semantics = [], scalar_prefetch = 0 : i64, scratch_operands = 0 : i64} {
%c0 = arith.constant 0 : index
%0 = memref.load %arg0[%c0] : memref<1xi32, #tpu.memory_space<smem>>
%c0_i32 = arith.constant 0 : i32
%1 = arith.cmpi ne, %0, %c0_i32 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32_0 = arith.constant 0 : i32
%2 = arith.select %1, %c1_i32, %c0_i32_0 : i32
%c0_1 = arith.constant 0 : index
%3 = memref.load %arg1[%c0_1] : memref<1xi32, #tpu.memory_space<smem>>
%c0_i32_2 = arith.constant 0 : i32
%4 = arith.cmpi ne, %3, %c0_i32_2 : i32
%c1_i32_3 = arith.constant 1 : i32
%c0_i32_4 = arith.constant 0 : i32
%5 = arith.select %4, %c1_i32_3, %c0_i32_4 : i32
%6 = arith.cmpi eq, %2, %5 : i32
%c0_5 = arith.constant 0 : index
%7 = memref.load %arg2[%c0_5] : memref<1xi32, #tpu.memory_space<smem>>
%c0_i32_6 = arith.constant 0 : i32
%8 = arith.cmpi ne, %7, %c0_i32_6 : i32
%9 = arith.extui %6 : i1 to i32
memref.store %9, %arg2[%c0_5] : memref<1xi32, #tpu.memory_space<smem>>
return
}
}
Traceback (most recent call last):
File "/home/ayx/jax/2.py", line 32, in <module>
main()
~~~~^^
File "/home/ayx/jax/2.py", line 28, in main
out = kernel(x, y)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'arith.cmpi'
at location: loc(fused["/eq"(callsite("kernel"("/home/ayx/jax/2.py":23:15) at callsite("main"("/home/ayx/jax/2.py":28:10) at "<module>"("/home/ayx/jax/2.py":32:4)))), "/select_n"(callsite("kernel"("/home/ayx/jax/2.py":21:8) at callsite("main"("/home/ayx/jax/2.py":28:10) at "<module>"("/home/ayx/jax/2.py":32:4))))])
The MLIR operation involved:
%21 = "arith.cmpi"(%16, %20) <{predicate = 0 : i64}> : (i1, i1) -> i1
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
Note that when the value of `a` is changed from
a = jnp.int32(1)
to another number, such as
a = jnp.int32(2)
the kernel compiles and runs successfully.
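The MLIR operation reported in the error message,
%21 = "arith.cmpi"(%16, %20) <{predicate = 0 : i64}> : (i1, i1) -> i1
suggests that `select(b, 1, 0)` has been eliminated. If the select operations were still present, the comparison's type signature would be
(i32, i32) -> i1
instead. Presumably `select(b, 1, 0)` on `i32` is canonicalized to `extui(b)`. The equality of two `extui` results is then folded into an equality of the underlying `i1` values, which Mosaic cannot legalize. This would also explain why `a = jnp.int32(2)` avoids the error, since `select(b, 2, 0)` is not equivalent to a zero-extension.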
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35.dev20241017+0519db15a
jaxlib: 0.4.34
numpy: 2.1.2
python: 3.13.0 (main, Oct 8 2024, 01:04:00) [Clang 18.1.8 ]
device info: TPU v5 lite-8, 8 local devices"
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-ab2ce832-w-0', release='5.19.0-1027-gcp', version='#29~22.04.1-Ubuntu SMP Thu Jun 22 05:13:17 UTC 2023', machine='x86_64')