
jax-metal: reduce window not supported

Status: open. Opened by jonatanklosko.

Description

import jax
import jax.numpy as jnp

def f(x):
  return jax.lax.reduce_window(
      x, 0, jnp.add, window_dimensions=(2,), window_strides=(1,),
      padding=[(0, 0)], base_dilation=(1,), window_dilation=(1,))

x = jnp.array([1, 2, 4])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<4xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %0 = "stablehlo.reduce_window"(%arg0, %c) <{base_dilations = array<i64: 1>, padding = dense<1> : tensor<1x2xi64>, window_dilations = array<i64: 1>, window_dimensions = array<i64: 2>, window_strides = array<i64: 1>}> ({
    ^bb0(%arg1: tensor<i32>, %arg2: tensor<i32>):
      %1 = stablehlo.add %arg1, %arg2 : tensor<i32>
      stablehlo.return %1 : tensor<i32>
    }) : (tensor<3xi32>, tensor<i32>) -> tensor<4xi32>
    return %0 : tensor<4xi32>
  }
}

Running jax.jit(f)(x) on the METAL device fails with:

Traceback (most recent call last):
  File "/Users/jonatanklosko/tmp/jax_mlir.py", line 71, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/jonatanklosko/tmp/jax_mlir.py:63:0: error: failed to legalize operation 'mhlo.reduce_window'
/Users/jonatanklosko/tmp/jax_mlir.py:70:0: note: called from
/Users/jonatanklosko/tmp/jax_mlir.py:63:0: note: see current operation:
%2 = "mhlo.reduce_window"(%arg0, %1) ({
^bb0(%arg1: tensor<si32>, %arg2: tensor<si32>):
  %3 = "mhlo.add"(%arg1, %arg2) : (tensor<si32>, tensor<si32>) -> tensor<si32>
  "mhlo.return"(%3) : (tensor<si32>) -> ()
}) {base_dilations = dense<1> : tensor<1xi64>, padding = dense<0> : tensor<1x2xi64>, window_dilations = dense<1> : tensor<1xi64>, window_dimensions = dense<2> : tensor<1xi64>, window_strides = dense<1> : tensor<1xi64>} : (tensor<3xsi32>, tensor<si32>) -> tensor<2xsi32>
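For reference, here is a minimal pure-Python sketch of what this 1-D reduce_window should compute. It assumes base and window dilation of 1 and fills padded positions with the init value, as in XLA's ReduceWindow semantics. reduce_window_1d is a hypothetical helper for illustration, not a JAX API.

```python
import operator

def reduce_window_1d(xs, init, op, window, stride=1, padding=(0, 0)):
    # Hypothetical helper, not a JAX API: reference semantics for a 1-D
    # reduce_window with base_dilation = window_dilation = 1.
    lo, hi = padding
    # Padded positions are filled with the init value (the reduction's identity).
    padded = [init] * lo + list(xs) + [init] * hi
    # Number of windows: floor((padded_len - window) / stride) + 1.
    out_len = (len(padded) - window) // stride + 1
    out = []
    for k in range(max(out_len, 0)):
        acc = init
        for v in padded[k * stride : k * stride + window]:
            acc = op(acc, v)
        out.append(acc)
    return out

print(reduce_window_1d([1, 2, 4], 0, operator.add, 2))                  # [3, 6]
print(reduce_window_1d([1, 2, 4], 0, operator.add, 2, padding=(1, 1)))  # [1, 3, 6, 4]
```

With padding=[(0, 0)] the output length is (3 - 2) // 1 + 1 = 2, which matches the tensor<2xsi32> result type in the error above. The HLO printed earlier shows padding = dense<1> and tensor<4xi32>. That corresponds to padding [(1, 1)], which would give [1, 3, 6, 4].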

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7
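Until mhlo.reduce_window is legalized by the Metal plugin, a stride-1, unpadded window sum like the one in f could in principle be expressed with static slices and additions instead. This is a sketch that assumes jax-metal can lower stablehlo.slice and stablehlo.add, which has not been verified here. window_sum_1d is a hypothetical helper, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax

def window_sum_1d(x, window):
    # Hypothetical workaround (untested on jax-metal): emulate
    # reduce_window(x, 0, add, (window,), (1,), [(0, 0)]) by adding
    # `window` shifted static slices, so only slice and add ops are emitted.
    if not 1 <= window <= x.shape[0]:
        raise ValueError("window must be between 1 and len(x)")
    n = x.shape[0] - window + 1  # number of output elements
    out = lax.slice(x, (0,), (n,))
    for i in range(1, window):
        out = out + lax.slice(x, (i,), (i + n,))
    return out

x = jnp.array([1, 2, 4])
f = jax.jit(window_sum_1d, static_argnums=1)
print(f.lower(x, 2).as_text())  # lowered module uses slice/add, no reduce_window
print(f(x, 2))                  # [3 6]
```

Each extra window element adds one slice and one add, so this only suits small, static windows.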

jonatanklosko, May 23 '24 10:05