Metal compile error
What happened?
Hi! When trying to compile a torch model coming from iree-turbine for metal, the compilation fails.
Example input:
#executable_target_metal_msl_fb = #hal.executable.target<"metal-spirv", "metal-msl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>}>
#loc6 = loc("<stdin>":8:10)
#loc7 = loc("<stdin>":11:10)
#map = affine_map<(d0) -> (d0)>
#device_target_metal = #hal.device.target<"metal", [#executable_target_metal_msl_fb]>
module @module attributes {hal.device.targets = [#device_target_metal]} {
util.func public @main$async(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {inlining_policy = #util.inline.never, iree.abi.model = "coarse-fences", iree.abi.stub} {
%cst = arith.constant dense_resource<torch_tensor_3_torch.float32> : tensor<3xf32> loc(#loc2)
%cst_0 = arith.constant dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32> loc(#loc2)
%cst_1 = arith.constant 0.000000e+00 : f32 loc(#loc2)
%0 = hal.tensor.import wait(%arg1) => %arg0 : !hal.buffer_view -> tensor<4xf32> loc(#loc3)
%expanded = tensor.expand_shape %0 [[0, 1]] : tensor<4xf32> into tensor<1x4xf32> loc(#loc4)
%1 = tensor.empty() : tensor<1x3xf32> loc(#loc5)
%2 = linalg.fill ins(%cst_1 : f32) outs(%1 : tensor<1x3xf32>) -> tensor<1x3xf32> loc(#loc5)
%3 = linalg.matmul ins(%expanded, %cst_0 : tensor<1x4xf32>, tensor<4x3xf32>) outs(%2 : tensor<1x3xf32>) -> tensor<1x3xf32> loc(#loc5)
%collapsed = tensor.collapse_shape %3 [[0, 1]] : tensor<1x3xf32> into tensor<3xf32> loc(#loc6)
%4 = tensor.empty() : tensor<3xf32> loc(#loc7)
%5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%collapsed, %cst : tensor<3xf32>, tensor<3xf32>) outs(%4 : tensor<3xf32>) {
^bb0(%in: f32 loc("<stdin>":8:10), %in_2: f32 loc("<stdin>":11:10), %out: f32 loc("<stdin>":11:10)):
%8 = arith.addf %in, %in_2 : f32 loc(#loc7)
linalg.yield %8 : f32 loc(#loc7)
} -> tensor<3xf32> loc(#loc7)
%6 = hal.tensor.barrier join(%5 : tensor<3xf32>) => %arg2 : !hal.fence loc(#loc1)
%7 = hal.tensor.export %6 : tensor<3xf32> -> !hal.buffer_view loc(#loc1)
util.return %7 : !hal.buffer_view loc(#loc8)
} loc(#loc1)
util.func public @main(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub} {
%c-1_i32 = arith.constant -1 : i32 loc(#loc1)
%c0 = arith.constant 0 : index loc(#loc1)
%device_0 = hal.devices.get %c0 : !hal.device loc(#loc1)
%0 = util.null : !hal.fence loc(#loc1)
%fence = hal.fence.create device(%device_0 : !hal.device) flags("None") : !hal.fence loc(#loc1)
%1 = util.call @main$async(%arg0, %0, %fence) : (!hal.buffer_view, !hal.fence, !hal.fence) -> !hal.buffer_view loc(#loc1)
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) : i32 loc(#loc1)
util.return %1 : !hal.buffer_view loc(#loc1)
} loc(#loc1)
} loc(#loc)
#loc = loc("<stdin>":1:1)
#loc1 = loc("<stdin>":2:3)
#loc2 = loc(unknown)
#loc3 = loc("<stdin>":2:19)
#loc4 = loc("<stdin>":4:10)
#loc5 = loc("<stdin>":6:10)
#loc8 = loc("<stdin>":12:5)
{-#
dialect_resources: {
builtin: {
torch_tensor_3_torch.float32: "0x040000001E042AC03DF18F3FCDF048BF",
torch_tensor_4_3_torch.float32: "0x04000000074E023FBEDE61BE69FB4F3FE5A1D5BD7C478EBF23BB2CBF78099A3F98EAC93ECDD8613E8F522BBE869FE83E45A9113F"
}
}
#-}
Full error:
failed to translate executables
<unknown>:0: error: failed to legalize operation 'arith.constant'
<unknown>:0: note: see current operation: %10 = "arith.constant"() <{value = dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>}> : () -> tensor<4x3xf32>
<stdin>:11:10: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"metal-spirv", "metal-msl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>}>
<stdin>:11:10: note: see current operation:
"hal.executable.variant"() ({
"hal.executable.export"() ({
^bb0(%arg0: !hal.device):
%0 = "arith.constant"() <{value = 1 : index}> : () -> index
"hal.return"(%0, %0, %0) : (index, index, index) -> ()
}) {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>, #hal.interface.binding<0, 2>], layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>, ordinal = 0 : index, sym_name = "main$async_dispatch_0_matmul_1x3x4_f32", translation_info = #iree_codegen.translation_info<SPIRVBaseDistribute>, workgroup_size = [32 : index, 1 : index, 1 : index]} : () -> ()
"builtin.module"() ({
"spirv.GlobalVariable"() <{binding = 0 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_0_", type = !spirv.ptr<none, StorageBuffer>}> : () -> ()
"spirv.GlobalVariable"() <{binding = 1 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_1_", type = !spirv.ptr<none, StorageBuffer>}> : () -> ()
"spirv.GlobalVariable"() <{binding = 2 : i32, descriptor_set = 0 : i32, sym_name = "__resource_var_0_2_", type = !spirv.ptr<none, StorageBuffer>}> : () -> ()
"func.func"() <{function_type = () -> (), sym_name = "main$async_dispatch_0_matmul_1x3x4_f32"}> ({
%0 = "arith.constant"() <{value = 4 : index}> : () -> index
%1 = "arith.constant"() <{value = 1 : index}> : () -> index
%2 = "arith.constant"() <{value = 3 : index}> : () -> index
%3 = "arith.constant"() <{value = 0 : index}> : () -> index
%4 = "arith.constant"() <{value = 0.000000e+00 : f32}> : () -> f32
%5 = "arith.constant"() <{value = dense_resource<torch_tensor_4_3_torch.float32> : tensor<4x3xf32>}> : () -> tensor<4x3xf32>
%6 = "hal.interface.binding.subspan"(%3, %0) {alignment = 64 : index, binding = 0 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 1>, set = 0 : index} : (index, index) -> memref<?xf32, #spirv.storage_class<StorageBuffer>>
%7 = "hal.interface.binding.subspan"(%3, %2) {alignment = 64 : index, binding = 1 : index, descriptor_flags = 1 : i32, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 1>, set = 0 : index} : (index, index) -> memref<?xf32, #spirv.storage_class<StorageBuffer>>
%8 = "hal.interface.binding.subspan"(%3, %2) {alignment = 64 : index, binding = 2 : index, descriptor_type = #hal.descriptor_type<storage_buffer>, operandSegmentSizes = array<i32: 1, 1>, set = 0 : index} : (index, index) -> memref<?xf32, #spirv.storage_class<StorageBuffer>>
%9 = "memref.alloc"() <{operandSegmentSizes = array<i32: 0, 0>}> : () -> memref<3xf32, #spirv.storage_class<Workgroup>>
%10 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
%11 = "gpu.block_dim"() <{dimension = #gpu<dim x>}> : () -> index
"scf.for"(%10, %2, %11) ({
^bb0(%arg0: index):
"memref.store"(%4, %9, %arg0) <{nontemporal = false}> : (f32, memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"scf.for"(%10, %2, %11) ({
^bb0(%arg0: index):
"scf.for"(%3, %0, %1) ({
^bb0(%arg1: index):
%12 = "memref.load"(%6, %arg1) <{nontemporal = false}> : (memref<?xf32, #spirv.storage_class<StorageBuffer>>, index) -> f32
%13 = "tensor.extract"(%5, %arg1, %arg0) : (tensor<4x3xf32>, index, index) -> f32
%14 = "memref.load"(%9, %arg0) <{nontemporal = false}> : (memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> f32
%15 = "arith.mulf"(%12, %13) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
%16 = "arith.addf"(%14, %15) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
"memref.store"(%16, %9, %arg0) <{nontemporal = false}> : (f32, memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"scf.for"(%10, %2, %11) ({
^bb0(%arg0: index):
%12 = "memref.load"(%9, %arg0) <{nontemporal = false}> : (memref<3xf32, #spirv.storage_class<Workgroup>>, index) -> f32
%13 = "memref.load"(%7, %arg0) <{nontemporal = false}> : (memref<?xf32, #spirv.storage_class<StorageBuffer>>, index) -> f32
%14 = "arith.addf"(%12, %13) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
"memref.store"(%14, %8, %arg0) <{nontemporal = false}> : (f32, memref<?xf32, #spirv.storage_class<StorageBuffer>>, index) -> ()
"scf.yield"() : () -> ()
}) : (index, index, index) -> ()
"func.return"() : () -> ()
}) {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 1, 1]>} : () -> ()
}) {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>} : () -> ()
"hal.executable.variant_end"() : () -> ()
}) {sym_name = "metal_msl_fb", target = #hal.executable.target<"metal-spirv", "metal-msl-fb", {spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Shader, Int8, Int16, Int64, Float16, UniformAndStorageBuffer8BitAccess, StorageBuffer8BitAccess, StoragePushConstant8, StorageUniform16, StorageBuffer16BitAccess, StoragePushConstant16, GroupNonUniform, GroupNonUniformVote, GroupNonUniformArithmetic, GroupNonUniformBallot, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformQuad, StoragePushConstant16, VariablePointers, VariablePointersStorageBuffer], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_storage_buffer_storage_class, SPV_KHR_variable_pointers]>, api=Metal, Apple:IntegratedGPU, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 1024, max_compute_workgroup_size = [1024 : i32, 1024 : i32, 1024 : i32]>>}>} : () -> ()
It looks like the dense_resource fails, but I don't have enough experience with SPIR-V to fix it myself...
Steps to reproduce your issue
- Run
`iree-compile --iree-hal-target-backends=metal-spirv` on the input IR.
What component(s) does this issue relate to?
Compiler
Version information
Installed via pip and should be the newest release
IREE (https://iree.dev):
IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1
LLVM version 19.0.0git
Optimized build
Additional context
The Metal docs on https://iree.dev/guides/deployment-configurations/gpu-metal/ are also missing; it would be nice if a very simple example with the correct flags were available... For now I had to look into the tests to find them.
Version information
Installed via pip and should be the newest release
IREE (https://iree.dev): IREE compiler version 20240410.859 @ b4273a4bfc66ba6dd8f62f6483d74d42a7b936f1 LLVM version 19.0.0git Optimized build
That is the latest stable release. Might want to also try the latest nightly (https://iree.dev/reference/bindings/python/#__tabbed_2_2):
python -m pip install --find-links https://iree.dev/pip-release-links.html --upgrade iree-compiler
failed to translate executables
<unknown>:0: error: failed to legalize operation 'arith.constant'
I saw a similar error on https://github.com/iree-org/iree/issues/17137 that was fixed by https://github.com/iree-org/iree/pull/17121. I'm not too confident that will also fix your issue here though, since that error message is quite generic for how far down in the compilation pipeline it occurs. Maybe someone more familiar with the workings of code generation could weigh in.
Additional context
The Metal docs on https://iree.dev/guides/deployment-configurations/gpu-metal/ are also missing; it would be nice if a very simple example with the correct flags were available... For now I had to look into the tests to find them.
fyi @antiagainst
That is the latest stable release. Might want to also try the latest nightly (https://iree.dev/reference/bindings/python/#__tabbed_2_2):
This does not fix the error unfortunately... If this is related to https://github.com/iree-org/iree/issues/17137 then this seems like a general SPIR-V bug and not something specific to metal.
I created https://github.com/llvm/llvm-project/pull/91318 to fix this
@ScottTodd My PR fixed this error. Maybe you could try https://github.com/iree-org/iree/issues/17137 to see if it is also fixed 🙂
@ScottTodd My PR fixed this error. Maybe you could try #17137 to see if it is also fixed 🙂
Ah, thanks for the note. The test failures on that other issue are still there (they are marked "expected to fail" / XFAIL and will block CI runs if they happen to start passing): https://github.com/iree-org/iree/actions/runs/9260597150/job/25474870074#step:9:86