
Compiling ASR frontend with VMVX causes 'util.buffer.store' complex operand error

Open phoenix-meadowlark opened this issue 3 years ago • 2 comments

What happened?

Compiling the ASR frontend with VMVX causes the following error:

/tmp/iree/libri/compute_frontend.mlir:89:10: error:
  'util.buffer.store' op operand #0 must be index or integer or floating-point, but got 'complex<f32>'
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^

(This has a lower priority than the other issues I've been posting, since it only affects VMVX.)

Steps to reproduce your issue

Simplified MLIR (no error is raised if the mhlo.convolution is elided in the Python source):

module @jit_compute_frontend {
  func.func public @main(%arg0: tensor<1x1600x1xf32>) -> tensor<1x10x129x1xf32> {
    %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<160x160xi32>
    %1 = mhlo.constant dense<0> : tensor<i32>
    %2 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<160x160xi32>
    %3 = mhlo.add %0, %2 : tensor<160x160xi32>
    %4 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<160x160xi32>
    %5 = mhlo.compare  EQ, %3, %4,  SIGNED : (tensor<160x160xi32>, tensor<160x160xi32>) -> tensor<160x160xi1>
    %6 = mhlo.convert %5 : (tensor<160x160xi1>) -> tensor<160x160xf32>
    %7 = mhlo.reshape %6 : (tensor<160x160xf32>) -> tensor<160x1x160xf32>
    %8 = mhlo.reshape %7 : (tensor<160x1x160xf32>) -> tensor<1x160x1x1x1x160xf32>
    %9 = mhlo.reshape %8 : (tensor<1x160x1x1x1x160xf32>) -> tensor<160x1x160xf32>
    %10 = mhlo.convolution(%arg0, %9) dim_numbers = [b, 0, f]x[o, i, 0]->[b, 0, f], window = {stride = [160], pad = [[0, 0]], lhs_dilate = [1], rhs_dilate = [1], reverse = [0]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x1600x1xf32>, tensor<160x1x160xf32>) -> tensor<1x10x160xf32>
    %11 = mhlo.reshape %10 : (tensor<1x10x160xf32>) -> tensor<1x10x1x160xf32>
    %12 = "mhlo.transpose"(%11) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x10x1x160xf32>) -> tensor<1x10x160x1xf32>
    %13 = "mhlo.transpose"(%12) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x10x160x1xf32>) -> tensor<1x10x1x160xf32>
    %14 = mhlo.constant dense<0> : tensor<i32>
    %15 = call @_pad(%13, %14) : (tensor<1x10x1x160xf32>, tensor<i32>) -> tensor<1x10x1x256xf32>
    %16 = call @fft(%15) : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
    %17 = "mhlo.transpose"(%16) {permutation = dense<[0, 1, 3, 2]> : tensor<4xi64>} : (tensor<1x10x1x129xcomplex<f32>>) -> tensor<1x10x129x1xcomplex<f32>>
    %18 = mhlo.abs %17 : (tensor<1x10x129x1xcomplex<f32>>) -> tensor<1x10x129x1xf32>
    return %18 : tensor<1x10x129x1xf32>
  }
  func.func private @_pad(%arg0: tensor<1x10x1x160xf32>, %arg1: tensor<i32>) -> tensor<1x10x1x256xf32> {
    %0 = "mhlo.broadcast_in_dim"(%arg1) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<4x2xi32>
    %1 = mhlo.convert %0 : (tensor<4x2xi32>) -> tensor<4x2xf32>
    %2 = mhlo.constant dense<0> : tensor<i32>
    %3 = "mhlo.broadcast_in_dim"(%2) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %4 = mhlo.constant dense<0> : tensor<i32>
    %5 = "mhlo.broadcast_in_dim"(%4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %6 = "mhlo.concatenate"(%3, %5) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %7 = "mhlo.gather"(%1, %6) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %8 = "mhlo.pad"(%arg0, %7) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %9 = mhlo.constant dense<0> : tensor<i32>
    %10 = "mhlo.broadcast_in_dim"(%9) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %11 = mhlo.constant dense<1> : tensor<i32>
    %12 = "mhlo.broadcast_in_dim"(%11) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %13 = "mhlo.concatenate"(%10, %12) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %14 = "mhlo.gather"(%1, %13) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %15 = "mhlo.pad"(%8, %14) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %16 = mhlo.constant dense<1> : tensor<i32>
    %17 = "mhlo.broadcast_in_dim"(%16) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %18 = mhlo.constant dense<0> : tensor<i32>
    %19 = "mhlo.broadcast_in_dim"(%18) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %20 = "mhlo.concatenate"(%17, %19) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %21 = "mhlo.gather"(%1, %20) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %22 = "mhlo.pad"(%15, %21) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %23 = mhlo.constant dense<1> : tensor<i32>
    %24 = "mhlo.broadcast_in_dim"(%23) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %25 = mhlo.constant dense<1> : tensor<i32>
    %26 = "mhlo.broadcast_in_dim"(%25) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %27 = "mhlo.concatenate"(%24, %26) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %28 = "mhlo.gather"(%1, %27) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %29 = "mhlo.pad"(%22, %28) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %30 = mhlo.constant dense<2> : tensor<i32>
    %31 = "mhlo.broadcast_in_dim"(%30) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %32 = mhlo.constant dense<0> : tensor<i32>
    %33 = "mhlo.broadcast_in_dim"(%32) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %34 = "mhlo.concatenate"(%31, %33) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %35 = "mhlo.gather"(%1, %34) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %36 = "mhlo.pad"(%29, %35) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %37 = mhlo.constant dense<2> : tensor<i32>
    %38 = "mhlo.broadcast_in_dim"(%37) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %39 = mhlo.constant dense<1> : tensor<i32>
    %40 = "mhlo.broadcast_in_dim"(%39) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %41 = "mhlo.concatenate"(%38, %40) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %42 = "mhlo.gather"(%1, %41) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %43 = "mhlo.pad"(%36, %42) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %44 = mhlo.constant dense<3> : tensor<i32>
    %45 = "mhlo.broadcast_in_dim"(%44) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %46 = mhlo.constant dense<0> : tensor<i32>
    %47 = "mhlo.broadcast_in_dim"(%46) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %48 = "mhlo.concatenate"(%45, %47) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %49 = "mhlo.gather"(%1, %48) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %50 = "mhlo.pad"(%43, %49) {edge_padding_high = dense<0> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x160xf32>
    %51 = mhlo.constant dense<3> : tensor<i32>
    %52 = "mhlo.broadcast_in_dim"(%51) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %53 = mhlo.constant dense<1> : tensor<i32>
    %54 = "mhlo.broadcast_in_dim"(%53) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>) -> tensor<1xi32>
    %55 = "mhlo.concatenate"(%52, %54) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
    %56 = "mhlo.gather"(%1, %55) {dimension_numbers = #mhlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1]>, indices_are_sorted = true, slice_sizes = dense<1> : tensor<2xi64>} : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<f32>
    %57 = "mhlo.pad"(%50, %56) {edge_padding_high = dense<[0, 0, 0, 96]> : tensor<4xi64>, edge_padding_low = dense<0> : tensor<4xi64>, interior_padding = dense<0> : tensor<4xi64>} : (tensor<1x10x1x160xf32>, tensor<f32>) -> tensor<1x10x1x256xf32>
    return %57 : tensor<1x10x1x256xf32>
  }
  func.func private @fft(%arg0: tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>> {
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
    return %0 : tensor<1x10x1x129xcomplex<f32>>
  }
}
iree-compile \
  --iree-hal-target-backends=vmvx \
  --iree-input-type=mhlo \
  /tmp/compute_frontend.mlir \
  -o /tmp/compute_frontend.vmfb
/tmp/iree/libri/compute_frontend.mlir:89:10: error: 'util.buffer.store' op operand #0 must be index or integer or floating-point, but got 'complex<f32>'
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^
/tmp/iree/libri/compute_frontend.mlir:20:11: note: called from
    %16 = "func.call"(%15) {callee = @fft} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
          ^
/tmp/iree/libri/compute_frontend.mlir:89:10: note: see current operation: "util.buffer.store"(%36, %23, %22, %39) : (complex<f32>, !util.buffer, index, index) -> ()
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^
/tmp/iree/libri/compute_frontend.mlir:89:10: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"vmvx", "vmvx-bytecode-fb">
    %0 = "mhlo.fft"(%arg0) {fft_length = dense<256> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
         ^
/tmp/iree/libri/compute_frontend.mlir:20:11: note: called from
    %16 = "func.call"(%15) {callee = @fft} : (tensor<1x10x1x256xf32>) -> tensor<1x10x1x129xcomplex<f32>>
          ^
/tmp/iree/libri/compute_frontend.mlir:89:10: note: see current operation:
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg0: !hal.device, %arg1: index, %arg2: index):
    %0 = "arith.constant"() {value = 3 : index} : () -> index
    %1 = "arith.constant"() {value = 5 : index} : () -> index
    %2 = "arith.constant"() {value = 1 : index} : () -> index
    "hal.return"(%0, %1, %2) : (index, index, index) -> ()
  }) {layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "_main_dispatch_13_generic_10x129", translation_info = #iree_codegen.translation_info<VMVXDefault>} : () -> ()
  "builtin.module"() ({
    "func.func"() ({
    ^bb0(%arg0: !util.buffer, %arg1: !util.buffer, %arg2: !util.list<!util.buffer>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32, %arg9: i32, %arg10: i32, %arg11: i32):
      %0 = "arith.constant"() {value = 0 : index} : () -> index
      %1 = "arith.constant"() {value = 1290 : index} : () -> index
      %2 = "arith.constant"() {value = 2 : index} : () -> index
      %3 = "arith.constant"() {value = 1 : index} : () -> index
      %4 = "arith.constant"() {value = 43 : index} : () -> index
      %5 = "arith.constant"() {value = 10240 : index} : () -> index
      %6 = "arith.constant"() {value = 20480 : index} : () -> index
      %7 = "arith.constant"() {value = 2560 : index} : () -> index
      %8 = "arith.constant"() {value = 0 : index} : () -> index
      %9 = "util.list.get"(%arg2, %8) : (!util.list<!util.buffer>, index) -> !util.buffer
      %10 = "arith.constant"() {value = 0 : index} : () -> index
      %11 = "util.list.get"(%arg2, %10) : (!util.list<!util.buffer>, index) -> !util.buffer
      %12 = "util.buffer.size"(%11) : (!util.buffer) -> index
      %13 = "arith.constant"() {value = 4 : index} : () -> index
      %14 = "arith.constant"() {value = 10240 : index} : () -> index
      %15 = "util.buffer.subspan"(%11, %12, %5, %14) : (!util.buffer, index, index, index) -> !util.buffer
      %16 = "arith.constant"() {value = 0 : index} : () -> index
      %17 = "util.list.get"(%arg2, %16) : (!util.list<!util.buffer>, index) -> !util.buffer
      %18 = "arith.constant"() {value = 1 : index} : () -> index
      %19 = "util.list.get"(%arg2, %18) : (!util.list<!util.buffer>, index) -> !util.buffer
      %20 = "util.buffer.size"(%19) : (!util.buffer) -> index
      %21 = "util.sizeof"() {sizedType = complex<f32>} : () -> index
      %22 = "arith.muli"(%21, %1) : (index, index) -> index
      %23 = "util.buffer.subspan"(%19, %20, %6, %22) : (!util.buffer, index, index, index) -> !util.buffer
      %24 = "arith.index_cast"(%arg3) : (i32) -> index
      %25 = "arith.index_cast"(%arg4) : (i32) -> index
      "scf.for"(%0, %2, %3) ({
      ^bb0(%arg12: index):
        "scf.for"(%0, %4, %3) ({
        ^bb0(%arg13: index):
          %26 = "affine.apply"(%arg12, %arg13, %25, %24) {map = affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + d1 + s0 * 512 + s1 * 43)>} : (index, index, index, index) -> index
          %27 = "util.buffer.size"(%9) : (!util.buffer) -> index
          %28 = "arith.constant"() {value = 4 : index} : () -> index
          %29 = "arith.muli"(%28, %26) : (index, index) -> index
          %30 = "util.buffer.load"(%9, %27, %29) : (!util.buffer, index, index) -> f32
          %31 = "affine.apply"(%arg12, %arg13, %25, %24) {map = affine_map<(d0, d1)[s0, s1] -> (d0 * 256 + d1 + s0 * 512 + s1 * 43 + 2560)>} : (index, index, index, index) -> index
          %32 = "util.buffer.size"(%17) : (!util.buffer) -> index
          %33 = "arith.constant"() {value = 4 : index} : () -> index
          %34 = "arith.muli"(%33, %31) : (index, index) -> index
          %35 = "util.buffer.load"(%17, %32, %34) : (!util.buffer, index, index) -> f32
          %36 = "complex.create"(%30, %35) : (f32, f32) -> complex<f32>
          %37 = "affine.apply"(%arg12, %arg13, %25, %24) {map = affine_map<(d0, d1)[s0, s1] -> (d0 * 129 + d1 + s0 * 258 + s1 * 43)>} : (index, index, index, index) -> index
          %38 = "util.sizeof"() {sizedType = complex<f32>} : () -> index
          %39 = "arith.muli"(%38, %37) : (index, index) -> index
          "util.buffer.store"(%36, %23, %22, %39) : (complex<f32>, !util.buffer, index, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "func.return"() : () -> ()
    }) {function_type = (!util.buffer, !util.buffer, !util.list<!util.buffer>, i32, i32, i32, i32, i32, i32, i32, i32, i32) -> (), sym_name = "_main_dispatch_13_generic_10x129"} : () -> ()
  }) : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "vmvx_bytecode_fb", target = #hal.executable.target<"vmvx", "vmvx-bytecode-fb">} : () -> ()

(The rest of the output seemed fairly redundant.)
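To compare VMVX output against a backend where this module does compile, a NumPy reference for the simplified @main may help. This is a sketch under two assumptions read off the IR above: the convolution with the reshaped 160x160 identity kernel and stride 160 only splits the 1600 samples into ten non-overlapping 160-sample frames, and every pad value gathered in @_pad is 0 (so @_pad just right-pads each frame from 160 to 256). The helper name compute_frontend_reference is hypothetical.

```python
import numpy as np


def compute_frontend_reference(x: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy approximation of @main in the simplified module.

    x: float32 array of shape (1, 1600, 1).
    Returns a float32 array of shape (1, 10, 129, 1).
    """
    # Assumed: the identity-kernel convolution with stride 160 is plain
    # framing into 10 non-overlapping frames of 160 samples.
    frames = x.reshape(1, 10, 160)
    # @_pad zero-pads each frame on the right by 96 samples (160 -> 256);
    # rfft with n=256 zero-pads the same way. RFFT of length 256 yields
    # 256 // 2 + 1 = 129 complex bins.
    spectrum = np.fft.rfft(frames, n=256, axis=-1)  # shape (1, 10, 129)
    # mhlo.abs on complex<f32> returns the magnitude as f32.
    magnitude = np.abs(spectrum).astype(np.float32)
    # Layout after the final transpose: (1, 10, 129, 1).
    return magnitude[..., np.newaxis]
```

For example, `compute_frontend_reference(np.ones((1, 1600, 1), np.float32))` has shape (1, 10, 129, 1), and its DC bin is 160 in every frame, since each frame sums 160 ones.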

What component(s) does this issue relate to?

Compiler

Version information

27960a3246e41acfa79b1101f625fc0a42b404ed

Additional context

No response

phoenix-meadowlark, Nov 02 '22 00:11

Cool, this is probably the first use of complex numbers that reaches this layer. It doesn't look like there's a transform that converts a complex tensor into a flattened 2xf32 tensor yet, just an mhlo-level one at compiler/src/iree/compiler/InputConversion/MHLO/ConvertComplexToReal.cpp. Such a transform would need to run on the pre-util.buffer IR during codegen to insert the x2 dimension, because once we've type-erased to !util.buffer we can't change the indexing like that (element N becomes element N*2, but only in the original interpretation of the tensor).

benvanik, Nov 02 '22 03:11
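The index change described in the comment above can be illustrated with a small Python model of the layout: a complex<f32> tensor of shape [N] rewritten as an f32 tensor of shape [N, 2], with real and imaginary parts interleaved. This is a hypothetical sketch of the layout only, assuming the usual interleaved (real, imag) memory layout of complex<f32>; it is not how ConvertComplexToReal.cpp or any IREE pass is implemented, and the helper names are made up for illustration.

```python
import struct


def complex_to_interleaved_f32(values):
    """Flatten complex values as an f32 tensor of shape [N, 2], row-major.

    Complex element n's real part lands at flat index 2*n and its
    imaginary part at flat index 2*n + 1.
    """
    flat = []
    for v in values:
        flat.extend((v.real, v.imag))
    return flat


def complex_element_f32_indices(n):
    """Flat f32 indices holding (real, imag) of complex element n."""
    return 2 * n, 2 * n + 1


def pack_complex64(values):
    """Byte view: each complex element becomes two little-endian f32s,
    i.e. 8 bytes, matching sizeof(complex<f32>) in the dispatch above."""
    return b"".join(struct.pack("<ff", v.real, v.imag) for v in values)
```

The point of the model: an index computed against the complex element type (n) must be doubled to address the f32 view (2*n). Once only a flat, type-erased buffer and a precomputed index remain, nothing records which interpretation the index was computed in.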

Sorry to revive an old bug (let me know if I should log a new issue instead), but I ran into the same thing. I don't use the VMVX backend at all and only noticed it when updating the stablehlo e2e tests. So this isn't a priority for me, but I do have it down to a nice simple repro. The following test compiles with llvm-cpu and fails with vmvx.

func.func @complex_extract() {
  %input = util.unfoldable_constant dense<[(0.1,0.2), (0.3,0.4), (0.5,0.6), (0.7,0.8)]> : tensor<4xcomplex<f32>>
  %1 = stablehlo.real %input : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
  %2 = stablehlo.imag %input : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
  check.expect_almost_eq_const(%1, dense<[0.1,0.3,0.5,0.7]> : tensor<4xf32>) : tensor<4xf32>
  check.expect_almost_eq_const(%2, dense<[0.2,0.4,0.6,0.8]> : tensor<4xf32>) : tensor<4xf32>
  return
}

As a side note, the existing stablehlo e2e test for complex numbers (iree/tests/e2e/stablehlo_ops/complex.mlir) works, but almost by accident. The test creates complex numbers from floats, does an operation, then splits them apart again. This works in VMVX, but only because the compiler is smart enough to never materialize the complex numbers in the first place.

pstarkcdpr, Dec 05 '25 01:12
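The folding mentioned in the side note, where extracting the real or imaginary part of a freshly created complex value collapses to the original operand, can be sketched as a toy bottom-up rewrite over a tiny expression tree. This is a hypothetical illustration of the idea only; in MLIR this kind of simplification comes from op folders and canonicalization patterns (for example on complex.re and complex.im), not from code like this.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Var:  # an opaque f32 value
    name: str


@dataclass(frozen=True)
class Create:  # models complex.create(re, im)
    re: "Expr"
    im: "Expr"


@dataclass(frozen=True)
class Re:  # models complex.re / stablehlo.real
    x: "Expr"


@dataclass(frozen=True)
class Im:  # models complex.im / stablehlo.imag
    x: "Expr"


Expr = Union[Var, Create, Re, Im]


def fold(e: Expr) -> Expr:
    """Fold bottom-up: Re(Create(a, b)) -> a and Im(Create(a, b)) -> b."""
    if isinstance(e, Create):
        return Create(fold(e.re), fold(e.im))
    if isinstance(e, (Re, Im)):
        inner = fold(e.x)
        if isinstance(inner, Create):
            return inner.re if isinstance(e, Re) else inner.im
        return type(e)(inner)  # nothing to fold; rebuild the same op
    return e
```

When every Create is consumed only by Re/Im, all complex nodes fold away, so no complex<f32> value is left to store into a buffer. In the complex_extract repro, the input comes from util.unfoldable_constant rather than from a create op, so presumably there is nothing to fold and the complex value has to be materialized.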