jax-metal: segmentation fault inside `jax.lax.while_loop`

Open jonatanklosko opened this issue 1 month ago • 4 comments

Description

```python
import jax
import jax.numpy as jnp

def f(x):
  def cond(carry):
    i, x, acc = carry
    return i < x.shape[0]

  def body(carry):
    i, x, acc = carry
    return (i + 1, x, acc + x[i])

  i = jnp.array(0)
  acc = jnp.array(0)

  return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
```
HLO

```
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %1) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %3 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.compare  LT, %iterArg, %3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %4 : tensor<i1>
    } do {
      %3 = stablehlo.constant dense<1> : tensor<i32>
      %4 = stablehlo.add %iterArg, %3 : tensor<i32>
      %5 = stablehlo.constant dense<0> : tensor<i32>
      %6 = stablehlo.compare  LT, %iterArg, %5,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %7 = stablehlo.convert %iterArg : tensor<i32>
      %8 = stablehlo.constant dense<3> : tensor<i32>
      %9 = stablehlo.add %7, %8 : tensor<i32>
      %10 = stablehlo.select %6, %9, %iterArg : tensor<i1>, tensor<i32>
      %11 = stablehlo.dynamic_slice %iterArg_0, %10, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
      %13 = stablehlo.convert %iterArg_1 : tensor<i32>
      %14 = stablehlo.add %13, %12 : tensor<i32>
      stablehlo.return %4, %iterArg_0, %14 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %2#0, %2#1, %2#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}
```

The loop above computes the sum of the tensor's elements. Running the code results in a segmentation fault.
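One way to check whether the crash is specific to the Metal backend is to run the same function pinned to the CPU device and compare it against `jnp.sum`. The sketch below is only a cross-check, not a fix. It assumes a CPU backend is available alongside jax-metal, and the names `cpu`, `i_out`, `x_out`, `acc_out`, and `expected` are illustrative rather than taken from the report.

```python
import jax
import jax.numpy as jnp

def f(x):
  def cond(carry):
    i, x, acc = carry
    return i < x.shape[0]

  def body(carry):
    i, x, acc = carry
    return (i + 1, x, acc + x[i])

  i = jnp.array(0)
  acc = jnp.array(0)

  return jax.lax.while_loop(cond, body, (i, x, acc))

# Assumption: a CPU backend is registered next to the Metal one.
cpu = jax.devices("cpu")[0]

# Create the input and run the jitted function on the CPU device only.
with jax.default_device(cpu):
  x = jnp.array([1, 2, 3])
  i_out, x_out, acc_out = jax.jit(f)(x)
  expected = jnp.sum(x)

print(i_out, acc_out, expected)
```

If this completes on CPU and `acc_out` matches `expected`, that would suggest the loop itself is well-formed and the fault lies in how the Metal plugin executes it, though it does not prove it.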

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

jonatanklosko • May 31 '24 12:05