
Metal compile error

maxbartel opened this issue 1 year ago · 3 comments

What happened?

Hi! When trying to compile a torch model coming from iree-turbine for Metal, the compilation fails.

Example input:

#executable_target_metal_msl_fb = #hal.executable.target<"metal-spirv", "metal-msl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>}>
#loc6 = loc("<stdin>":8:10)
#loc7 = loc("<stdin>":11:10)
#map = affine_map<(d0) -> (d0)>
#device_target_metal = #hal.device.target<"metal", [#executable_target_metal_msl_fb]>
module @module attributes {hal.device.targets = [#device_target_metal]} {
  util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
    %cst = arith.constant dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32> loc(#loc2)
    %cst_0 = arith.constant dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32> loc(#loc2)
    %cst_1 = arith.constant 0.000000e+00 : f32 loc(#loc2)
    %0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<4xf32> loc(#loc3)
    %expanded = tensor.expand_shape %0 [[0, 1]] : tensor<4xf32> into tensor<1x4xf32> loc(#loc4)
    %1 = tensor.empty() : tensor<1x3xf32> loc(#loc5)
    %2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x3xf32>) -> tensor<1x3xf32> loc(#loc5)
    %3 = linalg.matmul ins(%expanded, %cst_0 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%2 : tensor<1x3xf32>) -> tensor<1x3xf32> loc(#loc5)
    %collapsed = tensor.collapse_shape %3 [[0, 1]] : tensor<1x3xf32> into tensor<3xf32> loc(#loc6)
    %4 = tensor.empty() : tensor<3xf32> loc(#loc7)
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%collapsed, %cst : tensor<3xf32>, tensor<3xf32>) outs(%4 : tensor<3xf32>) {
    ^bb0(%in: f32 loc("<stdin>":8:10), %in_2: f32 loc("<stdin>":11:10), %out: f32 loc("<stdin>":11:10)):
      %8 = arith.addf %in, %in_2 : f32 loc(#loc7)
      linalg.yield %8 : f32 loc(#loc7)
    } -> tensor<3xf32> loc(#loc7)
    %6 = hal.tensor.barrier join(%5 : tensor<3xf32>) => %arg2 : !hal.fence loc(#loc1)
    %7 = hal.tensor.export %6 : tensor<3xf32> -> !hal.buffer_view loc(#loc1)
    util.return %7 : !hal.buffer_view loc(#loc8)
  } loc(#loc1)
  util.func public @main(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
    %c-1_i32 = arith.constant -1 : i32 loc(#loc1)
    %c0 = arith.constant 0 : index loc(#loc1)
    %device_0 = hal.devices.get %c0 : !hal.device loc(#loc1)
    %0 = util.null : !hal.fence loc(#loc1)
    %fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence loc(#loc1)
    %1 = util.call @main$async(%arg0, %0, %fence) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view loc(#loc1)
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32 loc(#loc1)
    util.return %1 : !hal.buffer_view loc(#loc1)
  } loc(#loc1)
} loc(#loc)
#loc = loc("<stdin>":1:1)
#loc1 = loc("<stdin>":2:3)
#loc2 = loc(unknown)
#loc3 = loc("<stdin>":2:19)
#loc4 = loc("<stdin>":4:10)
#loc5 = loc("<stdin>":6:10)
#loc8 = loc("<stdin>":12:5)

{-#
  dialect_resources: {
    builtin: {
      torch_tensor_3_torch.float32: "0x040000001E042AC03DF18F3FCDF048BF",
      torch_tensor_4_3_torch.float32: "0x04000000074E023FBEDE61BE69FB4F3FE5A1D5BD7C478EBF23BB2CBF78099A3F98EAC93ECDD8613E8F522BBE869FE83E45A9113F"
    }
  }
#-}
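For reference, each dense_resource blob above is a raw hex string that (as far as I can tell from MLIR's textual resource encoding; the leading-alignment-word layout is my assumption) starts with a 32-bit alignment word (0x04000000, i.e. 4) followed by packed little-endian IEEE-754 floats. A small Python sketch that decodes the 3-element bias blob:

```python
import struct

def decode_f32_resource(hex_blob: str):
    """Decode an MLIR dense_resource hex blob into a list of f32 values.

    Assumes the first 4 bytes are the alignment word that MLIR's textual
    resource encoding prepends, followed by little-endian IEEE-754 floats.
    """
    raw = bytes.fromhex(hex_blob.removeprefix("0x"))
    payload = raw[4:]  # skip the leading alignment word (0x04000000 here)
    count = len(payload) // 4
    return list(struct.unpack(f"<{count}f", payload))

bias = decode_f32_resource("0x040000001E042AC03DF18F3FCDF048BF")
print(bias)  # roughly [-2.657, 1.125, -0.785]
```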

Full error:

failed to translate executables
<unknown>:0: error: failed to legalize operation 'arith.constant'
<unknown>:0: note: see current operation: %10 = "arith.constant"() <{value = dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>}> : () -> tensor<4x3xf32>
<stdin>:11:10: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"metal-spirv", "metal-msl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>}>
<stdin>:11:10: note: see current operation: 
"hal.executable.variant"() ({
  "hal.executable.export"() ({
  ^bb0(%arg0: !hal.device):
    %0 = "arith.constant"() <{value = 1 : index}> : () -> index
    "hal.return"(%0, %0, %0) : (index, index, index) -> ()
  }) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "main$async_dispatch_0_matmul_1x3x4_f32", translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute>, workgroup_size = [32 : index, 1 : index, 1 : index]} : () -> ()
  "builtin.module"() ({
    "spirv.GlobalVariable"() <{binding = 0 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_0_", type = !spirv.ptr<none, StorageBuffer>}> : () -> ()
    "spirv.GlobalVariable"() <{binding = 1 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_1_", type = !spirv.ptr<none, StorageBuffer>}> : () -> ()
    "spirv.GlobalVariable"() <{binding = 2 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_2_", type = !spirv.ptr<none, StorageBuffer>}> : () -> ()
    "func.func"() <{function_type = () -> (), sym_name = "main$async_dispatch_0_matmul_1x3x4_f32"}> ({
      %0 = "arith.constant"() <{value = 4 : index}> : () -> index
      %1 = "arith.constant"() <{value = 1 : index}> : () -> index
      %2 = "arith.constant"() <{value = 3 : index}> : () -> index
      %3 = "arith.constant"() <{value = 0 : index}> : () -> index
      %4 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
      %5 = "arith.constant"() <{value = dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>}> : () -> tensor<4x3xf32>
      %6 = "hal.interface.binding.subspan"(%3, %0) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 1>, set = 0 : index} : (index, index) -> memref<?xf32, #spirv.storage_class<StorageBuffer>>
      %7 = "hal.interface.binding.subspan"(%3, %2) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 1>, set = 0 : index} : (index, index) -> memref<?xf32, #spirv.storage_class<StorageBuffer>>
      %8 = "hal.interface.binding.subspan"(%3, %2) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 1>, set = 0 : index} : (index, index) -> memref<?xf32, #spirv.storage_class<StorageBuffer>>
      %9 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<3xf32, #spirv.storage_class<Workgroup>>
      %10 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
      %11 = "gpu.block_dim"() <{dimension = #gpu<dim x>}> : () -> index
      "scf.for"(%10, %2, %11) ({
      ^bb0(%arg0: index):
        "memref.store"(%4, %9, %arg0) <{nontemporal = false}> : (f32, memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "scf.for"(%10, %2, %11) ({
      ^bb0(%arg0: index):
        "scf.for"(%3, %0, %1) ({
        ^bb0(%arg1: index):
          %12 = "memref.load"(%6, %arg1) <{nontemporal = false}> : (memref<?xf32, #spirv.storage_class<StorageBuffer>>, index) -> f32
          %13 = "tensor.extract"(%5, %arg1, %arg0) : (tensor<4x3xf32>, index, index) -> f32
          %14 = "memref.load"(%9, %arg0) <{nontemporal = false}> : (memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> f32
          %15 = "arith.mulf"(%12, %13) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %16 = "arith.addf"(%14, %15) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          "memref.store"(%16, %9, %arg0) <{nontemporal = false}> : (f32, memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> ()
          "scf.yield"() : () -> ()
        }) : (index, index, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "scf.for"(%10, %2, %11) ({
      ^bb0(%arg0: index):
        %12 = "memref.load"(%9, %arg0) <{nontemporal = false}> : (memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> f32
        %13 = "memref.load"(%7, %arg0) <{nontemporal = false}> : (memref<?xf32, #spirv.storage_class<StorageBuffer>>, index) -> f32
        %14 = "arith.addf"(%12, %13) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
        "memref.store"(%14, %8, %arg0) <{nontemporal = false}> : (f32, memref<?xf32, #spirv.storage_class<StorageBuffer>>, index) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "func.return"() : () -> ()
    }) {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} : () -> ()
  }) {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>} : () -> ()
  "hal.executable.variant_end"() : () -> ()
}) {sym_name = "metal_msl_fb", target = #hal.executable.target<"metal-spirv", "metal-msl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>}>} : () -> ()

It looks like legalizing the dense_resource constant fails, but I don't have enough experience with SPIR-V to fix it myself...
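For anyone checking a fixed build against known-good numbers: the failing dispatch is just a 4-to-3 linear layer, y = x @ W + b, with the two resource blobs as W (4x3, row-major) and b (3). A pure-Python reference sketch (my own, not IREE code, and assuming the blobs start with a 4-byte alignment word) to compare runtime output against:

```python
import struct

def f32s(hex_blob: str):
    # Skip the 4-byte alignment word that MLIR seems to prepend to textual
    # resource blobs (an assumption), then unpack little-endian f32s.
    raw = bytes.fromhex(hex_blob.removeprefix("0x"))[4:]
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))

B = f32s("0x040000001E042AC03DF18F3FCDF048BF")  # 3 biases
W = f32s("0x04000000074E023FBEDE61BE69FB4F3FE5A1D5BD7C478EBF23BB2CBF78099A3F98EAC93ECDD8613E8F522BBE869FE83E45A9113F")  # 4x3 row-major weights

def main(x):
    # Mirrors @main: tensor<4xf32> -> tensor<3xf32>, i.e. x @ W + b.
    return [sum(x[i] * W[i * 3 + j] for i in range(4)) + B[j] for j in range(3)]

print(main([1.0, 0.0, 0.0, 0.0]))  # row 0 of W plus the bias
```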

Steps to reproduce your issue

  1. Run iree-compile --iree-hal-target-backends=metal-spirv on the input IR.
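Concretely (assuming the IR above is saved as input.mlir; the output filename is arbitrary):

```shell
# Save the IR above as input.mlir, then compile for the Metal backend:
iree-compile \
  --iree-hal-target-backends=metal-spirv \
  input.mlir \
  -o module.vmfb
```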

What component(s) does this issue relate to?

Compiler

Version information

Installed via pip and should be the newest release

IREE (https://iree.dev):
  IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1
  LLVM version 19.0.0git
  Optimized build

Additional context

The Metal docs at https://iree.dev/guides/deployment-configurations/gpu-metal/ are also missing. It would be nice if a very simple example with the correct flags were available; for now I had to dig through the tests to find them.

maxbartel · Apr 30 '24 13:04

Version information

Installed via pip and should be the newest release

IREE (https://iree.dev):
  IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1
  LLVM version 19.0.0git
  Optimized build

That is the latest stable release. Might want to also try the latest nightly (https://iree.dev/reference/bindings/python/#__tabbed_2_2):

python -m pip install --find-links https://iree.dev/pip-release-links.html --upgrade iree-compiler

failed to translate executables
<unknown>:0: error: failed to legalize operation 'arith.constant'

I saw a similar error on https://github.com/iree-org/iree/issues/17137 that was fixed by https://github.com/iree-org/iree/pull/17121. I'm not too confident that will also fix your issue here though, since that error message is quite generic for how far down in the compilation pipeline it occurs. Maybe someone more familiar with the workings of code generation could weigh in.


Additional context

The Metal docs at https://iree.dev/guides/deployment-configurations/gpu-metal/ are also missing. It would be nice if a very simple example with the correct flags were available; for now I had to dig through the tests to find them.

fyi @antiagainst

ScottTodd · Apr 30 '24 15:04

That is the latest stable release. Might want to also try the latest nightly (https://iree.dev/reference/bindings/python/#__tabbed_2_2):

Unfortunately, this does not fix the error... If this is related to https://github.com/iree-org/iree/issues/17137, then it seems like a general SPIR-V bug and not something specific to Metal.

maxbartel · Apr 30 '24 15:04

I created https://github.com/llvm/llvm-project/pull/91318 to fix this

maxbartel · May 07 '24 11:05

@ScottTodd My PR fixed this error. Maybe you could try https://github.com/iree-org/iree/issues/17137 to see if it is also fixed 🙂

maxbartel · May 27 '24 10:05

@ScottTodd My PR fixed this error. Maybe you could try #17137 to see if it is also fixed 🙂

Ah, thanks for the note. The test failures on that other issue are still there (they are marked "expected to fail" / XFAIL and will block CI runs if they happen to start passing): https://github.com/iree-org/iree/actions/runs/9260597150/job/25474870074#step:9:86

ScottTodd · May 28 '24 15:05