jax-metal: segmentation fault inside `jax.lax.while_loop`
Description
import jax
import jax.numpy as jnp

def f(x):
    def cond(carry):
        i, x, acc = carry
        return i < x.shape[0]

    def body(carry):
        i, x, acc = carry
        return (i + 1, x, acc + x[i])

    i = jnp.array(0)
    acc = jnp.array(0)
    return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %1) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %3 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.compare LT, %iterArg, %3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %4 : tensor<i1>
    } do {
      %3 = stablehlo.constant dense<1> : tensor<i32>
      %4 = stablehlo.add %iterArg, %3 : tensor<i32>
      %5 = stablehlo.constant dense<0> : tensor<i32>
      %6 = stablehlo.compare LT, %iterArg, %5, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %7 = stablehlo.convert %iterArg : tensor<i32>
      %8 = stablehlo.constant dense<3> : tensor<i32>
      %9 = stablehlo.add %7, %8 : tensor<i32>
      %10 = stablehlo.select %6, %9, %iterArg : tensor<i1>, tensor<i32>
      %11 = stablehlo.dynamic_slice %iterArg_0, %10, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
      %13 = stablehlo.convert %iterArg_1 : tensor<i32>
      %14 = stablehlo.add %13, %12 : tensor<i32>
      stablehlo.return %4, %iterArg_0, %14 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %2#0, %2#1, %2#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}
The loop above computes the sum of the tensor's elements. Running this code results in a segmentation fault.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
jax-metal 0.0.7
I'm fairly sure the issue is specific to the dynamic slice inside the while loop. I have already run into this in several places, and removing the dynamic slice from the code makes the segfault go away.
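For anyone who needs this working in the meantime, here is a sketch of a possible workaround (my own, only checked on this toy example, and `f_masked` is just a name I made up): replace the per-element read `x[i]`, which lowers to `stablehlo.dynamic_slice`, with a masked reduction, which lowers to compare/select/reduce ops only.

import jax
import jax.numpy as jnp

def f_masked(x):
    def cond(carry):
        i, x, acc = carry
        return i < x.shape[0]

    def body(carry):
        i, x, acc = carry
        # Read element i via a mask instead of x[i], so no
        # stablehlo.dynamic_slice is emitted inside the while body.
        elem = jnp.sum(jnp.where(jnp.arange(x.shape[0]) == i, x, 0))
        return (i + 1, x, acc + elem)

    return jax.lax.while_loop(cond, body, (jnp.array(0), x, jnp.array(0)))

print(jax.jit(f_masked)(jnp.array([1, 2, 3])))  # expect a sum of 6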
The dynamic slice prevents the backend from encoding the whileOp. We are looking into a fix.
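As a quick check, one can search a function's lowered StableHLO text for the op to see whether it is likely to be affected (using `f` and `x` from the repro above):

hlo = jax.jit(f).lower(x).as_text()
print("dynamic_slice" in hlo)  # True for the repro above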
Running into the same issue in jax-metal 0.1.0
import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None

    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
jax.lax.scan hits the segfault as well, and its lowered HLO also contains a dynamic_slice.
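Until there is a fix, one possible stopgap (my own suggestion, not verified by the jax-metal team) is to pin the affected computation to JAX's CPU backend, which ships with jaxlib by default and does not exhibit the crash:

import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        return w * h, None
    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

# Run the affected function on the CPU backend as a stopgap;
# everything else can keep running on METAL.
with jax.default_device(jax.devices("cpu")[0]):
    print(jax.jit(f)(jnp.ones(1)))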
my system info:
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='[redacted]', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:16:51 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8103', machine='arm64')
I have the same issue :(
I'm having the same issue as well
Hi @jonatanklosko
The issue seems to have been fixed in jax-metal 0.1.1. I tested the provided repro on an M1 Pro Mac and it executes without any error or segmentation fault.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>>
>>> def f(x):
...     def cond(carry):
...         i, x, acc = carry
...         return i < x.shape[0]
...     def body(carry):
...         i, x, acc = carry
...         return (i + 1, x, acc + x[i])
...     i = jnp.array(0)
...     acc = jnp.array(0)
...     return jax.lax.while_loop(cond, body, (i, x, acc))
...
>>>
>>> x = jnp.array([1, 2, 3])
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1733393163.091288 2444566 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro
systemMemory: 16.00 GB
maxCacheSize: 5.33 GB
I0000 00:00:1733393163.106153 2444566 service.cc:145] XLA service 0x600002c95f00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733393163.106166 2444566 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1733393163.108507 2444566 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1733393163.108524 2444566 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
>>>
>>> # Print lowered HLO
>>> print(jax.jit(f).lower(x).as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32>) -> (tensor<i32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<i32> {jax.result_info = "[2]"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %0:3 = stablehlo.while(%iterArg = %c, %iterArg_1 = %arg0, %iterArg_2 = %c_0) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %c_3 = stablehlo.constant dense<3> : tensor<i32>
      %1 = stablehlo.compare LT, %iterArg, %c_3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    } do {
      %c_3 = stablehlo.constant dense<1> : tensor<i32>
      %1 = stablehlo.add %iterArg, %c_3 : tensor<i32>
      %c_4 = stablehlo.constant dense<0> : tensor<i32>
      %2 = stablehlo.compare LT, %iterArg, %c_4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %3 = stablehlo.convert %iterArg : tensor<i32>
      %c_5 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.add %3, %c_5 : tensor<i32>
      %5 = stablehlo.select %2, %4, %iterArg : tensor<i1>, tensor<i32>
      %6 = stablehlo.dynamic_slice %iterArg_1, %5, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %7 = stablehlo.reshape %6 : (tensor<1xi32>) -> tensor<i32>
      %8 = stablehlo.convert %iterArg_2 : tensor<i32>
      %9 = stablehlo.add %8, %7 : tensor<i32>
      stablehlo.return %1, %iterArg_1, %9 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %0#0, %0#1, %0#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}
>>> print(jax.jit(f)(x))
(Array(3, dtype=int32, weak_type=True), Array([1, 2, 3], dtype=int32), Array(6, dtype=int32))
>>> jax.print_environment_info()
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct 2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
device info: Metal-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')
Could you please verify with jax-metal 0.1.1 whether the issue still persists?
Thank you.
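In case it helps, upgrading in a pip-managed environment is just:

python -m pip install --upgrade jax-metal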
Yes, this appears to be solved in 0.1.1. The example from @abrasumente233 also works. Thanks :)