jax
jax copied to clipboard
jax-metal: segmentation fault inside `jax.lax.while_loop`
Description
import jax
import jax.numpy as jnp
def f(x):
    """Sum the elements of ``x`` using ``jax.lax.while_loop``.

    Minimal reproducer for the jax-metal while_loop segfault. The loop
    carry is ``(i, x, acc)``: an index, the input array passed through
    unchanged, and a running total.

    Args:
        x: array to sum. Indexing ``x[i]`` with a scalar ``i`` and bounding
            the loop by ``x.shape[0]`` presumably assume a 1-D array.

    Returns:
        The final carry ``(i, x, acc)``. ``i`` equals ``x.shape[0]``, ``x``
        is returned as-is, and ``acc`` holds the sum of its elements.
    """

    def cond(carry):
        i, xs, _ = carry
        return i < xs.shape[0]

    def body(carry):
        i, xs, acc = carry
        return (i + 1, xs, acc + xs[i])

    i = jnp.array(0)
    # while_loop requires the carry's dtype to stay fixed across iterations.
    # Start the accumulator in the dtype that `acc + x[i]` promotes to. For
    # int32 input this is the same int32 zero as before. For float input it
    # avoids a body/carry type-mismatch error.
    acc = jnp.array(0)
    acc = acc.astype(jnp.result_type(acc, x))
    return jax.lax.while_loop(cond, body, (i, x, acc))
# Sample input. Build one jitted wrapper and reuse it for lowering and running.
x = jnp.array([1, 2, 3])
jitted_f = jax.jit(f)

# Print lowered HLO
lowered_text = jitted_f.lower(x).as_text()
print(lowered_text)

# Execute the loop (this is where the segfault is reported on jax-metal).
print(jitted_f(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
%0 = stablehlo.constant dense<0> : tensor<i32>
%1 = stablehlo.constant dense<1> : tensor<i32>
%2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %1) : tensor<i32>, tensor<3xi32>, tensor<i32>
cond {
%3 = stablehlo.constant dense<3> : tensor<i32>
%4 = stablehlo.compare LT, %iterArg, %3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
stablehlo.return %4 : tensor<i1>
} do {
%3 = stablehlo.constant dense<1> : tensor<i32>
%4 = stablehlo.add %iterArg, %3 : tensor<i32>
%5 = stablehlo.constant dense<0> : tensor<i32>
%6 = stablehlo.compare LT, %iterArg, %5, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = stablehlo.convert %iterArg : tensor<i32>
%8 = stablehlo.constant dense<3> : tensor<i32>
%9 = stablehlo.add %7, %8 : tensor<i32>
%10 = stablehlo.select %6, %9, %iterArg : tensor<i1>, tensor<i32>
%11 = stablehlo.dynamic_slice %iterArg_0, %10, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
%12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
%13 = stablehlo.convert %iterArg_1 : tensor<i32>
%14 = stablehlo.add %13, %12 : tensor<i32>
stablehlo.return %4, %iterArg_0, %14 : tensor<i32>, tensor<3xi32>, tensor<i32>
}
return %2#0, %2#1, %2#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
}
}
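The loop above computes the sum of the tensor's elements. Running the code results in a segmentation fault. (Note: the HLO shown initializes the accumulator to 1 (`%1`), whereas the Python snippet uses `jnp.array(0)`; the dump appears to come from a slightly different version of the snippet.)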
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
jax-metal 0.0.7