
jax-metal: segmentation fault inside `jax.lax.while_loop`

Open jonatanklosko opened this issue 1 year ago • 4 comments

Description

import jax
import jax.numpy as jnp

def f(x):
  def cond(carry):
    i, x, acc = carry
    return i < x.shape[0]

  def body(carry):
    i, x, acc = carry
    return (i + 1, x, acc + x[i])

  i = jnp.array(0)
  acc = jnp.array(0)

  return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %1) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %3 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.compare  LT, %iterArg, %3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %4 : tensor<i1>
    } do {
      %3 = stablehlo.constant dense<1> : tensor<i32>
      %4 = stablehlo.add %iterArg, %3 : tensor<i32>
      %5 = stablehlo.constant dense<0> : tensor<i32>
      %6 = stablehlo.compare  LT, %iterArg, %5,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %7 = stablehlo.convert %iterArg : tensor<i32>
      %8 = stablehlo.constant dense<3> : tensor<i32>
      %9 = stablehlo.add %7, %8 : tensor<i32>
      %10 = stablehlo.select %6, %9, %iterArg : tensor<i1>, tensor<i32>
      %11 = stablehlo.dynamic_slice %iterArg_0, %10, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
      %13 = stablehlo.convert %iterArg_1 : tensor<i32>
      %14 = stablehlo.add %13, %12 : tensor<i32>
      stablehlo.return %4, %iterArg_0, %14 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %2#0, %2#1, %2#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}

The loop above computes the sum of the tensor's elements. Running the code results in a segmentation fault.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

jonatanklosko avatar May 31 '24 12:05 jonatanklosko

Pretty sure the issue is specific to the dynamic slice inside the while loop. I have already run into this in several places, and removing the dynamic slice from the code makes the segfault go away.
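If the dynamic slice is indeed the trigger, one possible workaround (a sketch, not verified on Metal) is to replace the `x[i]` indexing with a mask-and-reduce, which lowers to iota/compare/select/reduce instead of `stablehlo.dynamic_slice`:

```python
import jax
import jax.numpy as jnp

def f(x):
  def cond(carry):
    i, x, acc = carry
    return i < x.shape[0]

  def body(carry):
    i, x, acc = carry
    # Mask-and-reduce instead of x[i]: select the i-th element via a
    # comparison against an iota, so no dynamic_slice appears in the HLO.
    elem = jnp.sum(jnp.where(jnp.arange(x.shape[0]) == i, x, 0))
    return (i + 1, x, acc + elem)

  i = jnp.array(0)
  acc = jnp.array(0)
  return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])
print(jax.jit(f)(x))  # the accumulated sum is 6
```

This trades the O(1) slice for an O(n) masked reduction per iteration, so it is only sensible as a stopgap on small arrays.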

jonatanklosko avatar May 31 '24 14:05 jonatanklosko

The dynamic slice prevents the backend from encoding the whileOp. We are looking into a fix.
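For anyone wanting to check whether their function hits this combination before running it, a quick heuristic (a sketch; it greps the lowered StableHLO text rather than walking the IR, so it can produce false positives when the two ops are unrelated) might look like:

```python
import jax
import jax.numpy as jnp

def has_dynamic_slice_and_while(fn, *args):
    """Heuristic: does the lowered StableHLO contain both a while loop
    and a dynamic_slice (the combination reported to crash jax-metal)?"""
    hlo = jax.jit(fn).lower(*args).as_text()
    return "stablehlo.while" in hlo and "stablehlo.dynamic_slice" in hlo

# Check the repro from this issue.
def f(x):
    def cond(carry):
        i, x, acc = carry
        return i < x.shape[0]
    def body(carry):
        i, x, acc = carry
        return (i + 1, x, acc + x[i])
    return jax.lax.while_loop(cond, body, (jnp.array(0), x, jnp.array(0)))

print(has_dynamic_slice_and_while(f, jnp.array([1, 2, 3])))  # True
```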

shuhand0 avatar Jun 03 '24 20:06 shuhand0

Running into the same issue in jax-metal 0.1.0

acranej avatar Jun 11 '24 19:06 acranej

import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None

    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))

jax.lax.scan hits a segfault as well, and its lowered HLO also contains a dynamic_slice.
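Until the fix lands, one stopgap is to run the affected function on the CPU backend, which skips the Metal lowering entirely (a sketch, assuming the CPU backend is available alongside jax-metal):

```python
import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None
    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)

# Force this one computation onto the CPU backend instead of METAL.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    carry, _ = jax.jit(f)(x)
print(carry)
```

This obviously gives up GPU execution for that function, so it is only useful to keep the rest of a program running while the Metal bug is open.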

my system info:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='[redacted]', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:16:51 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8103', machine='arm64')

rumisle avatar Jun 22 '24 16:06 rumisle

I have the same issue :(

aniquetahir avatar Jul 01 '24 22:07 aniquetahir

I'm having the same issue as well

vyeevani avatar Jul 19 '24 16:07 vyeevani

Hi @jonatanklosko

The issue seems to have been fixed in jax-metal 0.1.1. I tested the provided repro on an M1 Pro Mac and it executed without any error or segmentation fault.

>>> import jax
>>> import jax.numpy as jnp
>>> 
>>> 
>>> def f(x):
...   def cond(carry):
...     i, x, acc = carry
...     return i < x.shape[0]
...   def body(carry):
...     i, x, acc = carry
...     return (i + 1, x, acc + x[i])
...   i = jnp.array(0)
...   acc = jnp.array(0)
...   return jax.lax.while_loop(cond, body, (i, x, acc))
... 
>>> 
>>> x = jnp.array([1, 2, 3])
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1733393163.091288 2444566 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1733393163.106153 2444566 service.cc:145] XLA service 0x600002c95f00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733393163.106166 2444566 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1733393163.108507 2444566 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1733393163.108524 2444566 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
>>> 
>>> # Print lowered HLO
>>> print(jax.jit(f).lower(x).as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32>) -> (tensor<i32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<i32> {jax.result_info = "[2]"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %0:3 = stablehlo.while(%iterArg = %c, %iterArg_1 = %arg0, %iterArg_2 = %c_0) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %c_3 = stablehlo.constant dense<3> : tensor<i32>
      %1 = stablehlo.compare  LT, %iterArg, %c_3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    } do {
      %c_3 = stablehlo.constant dense<1> : tensor<i32>
      %1 = stablehlo.add %iterArg, %c_3 : tensor<i32>
      %c_4 = stablehlo.constant dense<0> : tensor<i32>
      %2 = stablehlo.compare  LT, %iterArg, %c_4,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %3 = stablehlo.convert %iterArg : tensor<i32>
      %c_5 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.add %3, %c_5 : tensor<i32>
      %5 = stablehlo.select %2, %4, %iterArg : tensor<i1>, tensor<i32>
      %6 = stablehlo.dynamic_slice %iterArg_1, %5, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %7 = stablehlo.reshape %6 : (tensor<1xi32>) -> tensor<i32>
      %8 = stablehlo.convert %iterArg_2 : tensor<i32>
      %9 = stablehlo.add %8, %7 : tensor<i32>
      stablehlo.return %1, %iterArg_1, %9 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %0#0, %0#1, %0#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}

>>> print(jax.jit(f)(x))
(Array(3, dtype=int32, weak_type=True), Array([1, 2, 3], dtype=int32), Array(6, dtype=int32))
>>> jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct  2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
device info: Metal-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')

Could you please verify with jax-metal 0.1.1 whether the issue still persists?

Thank you.

rajasekharporeddy avatar Dec 05 '24 10:12 rajasekharporeddy

Yes, this appears to be solved in 0.1.1. The example from @abrasumente233 also works. Thanks :)

jonatanklosko avatar Dec 05 '24 11:12 jonatanklosko