[rocm] Wan2.1 layer norm dispatch causes runtime semaphore abort
What happened?
Dispatch (benchmark) IR (wan_dispatch_7.mlir):
module {
util.global private @__device_0 = #hal.device.target<"hip", [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, ukernels = "none"}>]> : !hal.device
hal.executable private @forward_t2v_bs1$async_dispatch_7 {
hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, ukernels = "none"}>) {
hal.executable.export public @forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) count(%arg0: !hal.device) -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
hal.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>} {
%cst = arith.constant 5.120000e+03 : f32
%cst_0 = arith.constant 9.99999997E-7 : f32
%c1548590400 = arith.constant 1548590400 : index
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c1548590400) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600x5120xf32>>
%1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600xf32>>
%2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<75600x5120xf32>>
%3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [75600, 5120], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600x5120xf32>> -> tensor<75600x5120xf32>
%4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0], sizes = [75600], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600xf32>> -> tensor<75600xf32>
%5 = tensor.empty() : tensor<75600x5120xf32>
%6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %4 : tensor<75600x5120xf32>, tensor<75600xf32>) outs(%5 : tensor<75600x5120xf32>) attrs = {lowering_config = #iree_gpu.lowering_config<{thread = [1, 4], workgroup = [1, 256]}>} {
^bb0(%in: f32, %in_1: f32, %out: f32):
%7 = arith.divf %in_1, %cst : f32
%8 = arith.addf %7, %cst_0 : f32
%9 = math.rsqrt %8 : f32
%10 = arith.mulf %in, %9 : f32
linalg.yield %10 : f32
} -> tensor<75600x5120xf32>
iree_tensor_ext.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [75600, 5120], strides = [1, 1] : tensor<75600x5120xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<75600x5120xf32>>
return
}
}
}
}
util.global private mutable @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer
util.initializer {
%device, %queue_affinity = hal.device.resolve on(#hal.device.affinity<@__device_0>) : !hal.device, i64
%allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator
%memory_type = hal.memory_type<"DeviceVisible|DeviceLocal"> : i32
%buffer_usage = hal.buffer_usage<"TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage"> : i32
%c4645166592 = arith.constant 4645166592 : index
%buffer = hal.allocator.allocate<%allocator : !hal.allocator> affinity(%queue_affinity) type(%memory_type) usage(%buffer_usage) : !hal.buffer{%c4645166592}
util.global.store %buffer, @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer
util.return
}
util.func public @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32(%arg0: i32) attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
%0 = arith.index_cast %arg0 : i32 to index
%device, %queue_affinity = hal.device.resolve on(#hal.device.affinity<@__device_0>) : !hal.device, i64
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot|AllowInlineExecution") categories(Dispatch) affinity(%queue_affinity) : !hal.command_buffer
%forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer = util.global.load @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer
%c0 = arith.constant 0 : index
%c3096878400 = arith.constant 3096878400 : index
%c3096878592 = arith.constant 3096878592 : index
%c1548288000 = arith.constant 1548288000 : index
%workgroup_x, %workgroup_y, %workgroup_z = hal.executable.calculate_workgroups device(%device : !hal.device) target(@forward_t2v_bs1$async_dispatch_7::@rocm_hsaco_fb::@forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32) : index, index, index
%exe = hal.executable.lookup device(%device : !hal.device) executable(@forward_t2v_bs1$async_dispatch_7) : !hal.executable
%ordinal = hal.executable.export.ordinal target(@forward_t2v_bs1$async_dispatch_7::@rocm_hsaco_fb::@forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32) : index
%c1 = arith.constant 1 : index
scf.for %arg1 = %c0 to %0 step %c1 {
hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z]) bindings([
(%forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer)[%c0, %c3096878400],
(%forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer)[%c3096878592, %c1548288000]
]) flags("None")
hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
}
hal.command_buffer.finalize<%cmd : !hal.command_buffer>
%1 = util.null : !hal.fence
%fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
hal.device.queue.execute<%device : !hal.device> affinity(%queue_affinity) wait(%1) signal(%fence) commands(%cmd) flags("None")
%c-1_i32 = arith.constant -1 : i32
%status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
util.status.check_ok %status, "failed to wait on timepoint"
util.return
}
}
Compile command:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-execution-model=async-external wan_dispatch_7.mlir -o wan_dispatch_7_gfx942.vmfb
Runtime command:
iree-benchmark-module --module=wan_dispatch_7_gfx942.vmfb --device=hip
Runtime output:
c/runtime/src/iree/hal/drivers/hip/event_semaphore.c:786: ABORTED; the semaphore was aborted; while invoking native function hal.fence.await; while calling import;
[ 0] bytecode module.forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32:386
Steps to reproduce your issue
- Create wan_dispatch_7.mlir and paste in the IR provided at the top of this issue.
- Build the latest top of IREE main (iree-org/iree@e5b65b533fec33c94e5cc7a5079b1c51131d46a7 when the issue was encountered) or pip install via
  pip install --pre iree-base-compiler iree-base-runtime -f https://iree.dev/pip-release-links.html
- Run the compile command.
- Run the benchmark command.
What component(s) does this issue relate to?
Runtime, Compiler
Version information
IREE (https://iree.dev): IREE compiler version 3.6.0rc20250619 @ e22bed8c615a5b27673d79f067a29533a6848dda LLVM version 21.0.0git Optimized build
Also reproduced with a debug source build @ 6c6fa9578acb3955f627e1c784e441de9a30ef1c.
Additional context
Full model IR: https://gist.github.com/monorimet/2ba3058b19e1fe6267820cc0b339f35a
Dispatch print-after-all output: https://sharkpublic.blob.core.windows.net/sharkpublic/ean/wan_ln_out.txt
A wild guess is that this is a bad fusion, because creating an export for just the layer norm op with identical inputs/params does not encounter the runtime abort; i.e., the following MLIR compiles and runs successfully:
module @module {
  func.func @main(%arg0: !torch.vtensor<[1,75600,5120],f32>) -> !torch.vtensor<[1,75600,5120],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int5120 = torch.constant.int 5120
    %0 = torch.prim.ListConstruct %int5120 : (!torch.int) -> !torch.list<int>
    %none = torch.constant.none
    %none_0 = torch.constant.none
    %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7
    %true = torch.constant.bool true
    %1 = torch.aten.layer_norm %arg0, %0, %none, %none_0, %float9.999990e-07, %true : !torch.vtensor<[1,75600,5120],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.float, !torch.bool -> !torch.vtensor<[1,75600,5120],f32>
    return %1 : !torch.vtensor<[1,75600,5120],f32>
  }
}
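For completeness, this reduced module should be compilable with the same flags as the dispatch IR above (the file name below is just a placeholder):
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-execution-model=async-external wan_layer_norm_only.mlir -o wan_layer_norm_only_gfx942.vmfb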
I was able to minimize further with the following torch IR, which includes a few operations on the norm input tensor that appear in the model:
module @module {
  func.func @main(%arg0: !torch.vtensor<[1,5120,21,45,80],bf16>) -> !torch.vtensor<[1,75600,5120],bf16> attributes {torch.assume_strict_symbolic_shapes} {
    %int1 = torch.constant.int 1
    %int5120 = torch.constant.int 5120
    %int75600 = torch.constant.int 75600
    %0 = torch.prim.ListConstruct %int1, %int5120, %int75600 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,5120,21,45,80],bf16>, !torch.list<int> -> !torch.vtensor<[1,5120,75600],bf16>
    %int1_0 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %2 = torch.aten.transpose.int %1, %int1_0, %int2 : !torch.vtensor<[1,5120,75600],bf16>, !torch.int, !torch.int -> !torch.vtensor<[1,75600,5120],bf16>
    %int1_1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int5120_2 = torch.constant.int 5120
    %3 = torch.prim.ListConstruct %int1_1, %int0, %int5120_2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %none = torch.constant.none
    %none_3 = torch.constant.none
    %none_4 = torch.constant.none
    %false = torch.constant.bool false
    %4 = torch.aten.new_zeros %2, %3, %none, %none_3, %none_4, %false : !torch.vtensor<[1,75600,5120],bf16>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,0,5120],bf16>
    %5 = torch.prim.ListConstruct %2, %4 : (!torch.vtensor<[1,75600,5120],bf16>, !torch.vtensor<[1,0,5120],bf16>) -> !torch.list<vtensor>
    %int1_5 = torch.constant.int 1
    %6 = torch.aten.cat %5, %int1_5 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,75600,5120],bf16>
    %7 = torch.prim.ListConstruct %6 : (!torch.vtensor<[1,75600,5120],bf16>) -> !torch.list<vtensor>
    %int0_6 = torch.constant.int 0
    %8 = torch.aten.cat %7, %int0_6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,75600,5120],bf16>
    %int5120_7 = torch.constant.int 5120
    %9 = torch.prim.ListConstruct %int5120_7 : (!torch.int) -> !torch.list<int>
    %none_8 = torch.constant.none
    %none_9 = torch.constant.none
    %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7
    %true = torch.constant.bool true
    %10 = torch.aten.layer_norm %8, %9, %none_8, %none_9, %float9.999990e-07, %true : !torch.vtensor<[1,75600,5120],bf16>, !torch.list<int>, !torch.none, !torch.none, !torch.float, !torch.bool -> !torch.vtensor<[1,75600,5120],bf16>
    return %10 : !torch.vtensor<[1,75600,5120],bf16>
  }
}
Hopefully this leads to a root cause soon.
It seems to be an issue with a transpose on the norm input -- the abort also goes away when I increase --iree-opt-level to O2 or higher. Perhaps we are relying on some advanced fusion techniques or transpose propagations to handle these transposes correctly?
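For reference, the workaround mentioned above is just the original compile command with the opt level raised:
iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-execution-model=async-external --iree-opt-level=O2 wan_dispatch_7.mlir -o wan_dispatch_7_gfx942.vmfb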
@AWoloszyn if you have cycles for it, your context on HIP runtime may help here the most. PTAL
Output with AMD_LOG_LEVEL=3: Azure
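For anyone reproducing: that log should be obtainable by prefixing the benchmark command above with the environment variable, e.g.
AMD_LOG_LEVEL=3 iree-benchmark-module --module=wan_dispatch_7_gfx942.vmfb --device=hip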
gridDimY must be at MOST std::numeric_limits<uint16_t>::max() == 65535. It is failing to launch the kernel because it has a gridDimY of 75600.
Yep! rocminfo will show you the limits for your particular device, but nowadays the norm is:
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
The compiler needs to ensure it doesn't produce big Y/Z dims.
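For reference, the same per-dimension limits can also be queried programmatically; a minimal standalone HIP sketch (not IREE code):

// query_grid_limits.c -- standalone illustration, compile with hipcc.
// Prints the per-dimension grid limits that rocminfo reports above.
#include <hip/hip_runtime.h>
#include <stdio.h>

int main(void) {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, /*device=*/0) != hipSuccess) {
    fprintf(stderr, "failed to query device 0\n");
    return 1;
  }
  // On gfx942 this is expected to print 2147483647 x 65535 x 65535,
  // matching the rocminfo output above.
  printf("max grid dims: %d x %d x %d\n", props.maxGridSize[0],
         props.maxGridSize[1], props.maxGridSize[2]);
  return 0;
}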
@benvanik Is there somewhere we could validate this at compile-time for AMD backends? It has come up a couple times.
It's not possible to do reliably in the compiler (if there are dynamic shapes there's nothing we can do, if the values are derived from runtime queries, etc.). The HIP HAL could check and return an error during HIP command buffer dispatch recording - that won't work for indirect dispatches but may be good enough for most uses today. The best "solution" is to change codegen to bias towards putting more work in X (failures will still happen, but there are clearly cases today where X is getting fewer work items than Y/Z). Int range analysis could help codegen decide when to do that (as we often roughly know the dimensions). I suppose we could add checks if it's possible for dynamic values to result in larger sizes, but probably only in < O2 or something.
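As a rough illustration of the record-time check suggested above (this is not IREE's actual HAL code; the names and constant are assumptions based on the limits listed earlier):

// Hypothetical record-time validation sketch -- not IREE's real HIP HAL code.
// Rejects direct dispatches whose Y/Z workgroup counts exceed the HIP grid
// limits; X allows up to 2147483647 and is not the failing dimension here.
#include <stdbool.h>
#include <stdint.h>

#define HIP_MAX_GRID_DIM_YZ 65535u  // per the rocminfo output above

static bool grid_dims_are_launchable(uint32_t x, uint32_t y, uint32_t z) {
  (void)x;
  return y <= HIP_MAX_GRID_DIM_YZ && z <= HIP_MAX_GRID_DIM_YZ;
}

For this dispatch the recorded workgroup count has Y = 75600 (per the AMD_LOG_LEVEL=3 output), so a check like this would surface the failure at record time instead of as a semaphore abort at execution time; as noted, it would not cover indirect dispatches.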
Thank you @AWoloszyn, reducing the sequence length (which in this model means restricting the height and width of the output) to below 65535 does circumvent the error, though it is obviously not a solution and only corroborates the identified root cause.
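(For reference, the 75600 sequence length is the flattened video latent 21 x 45 x 80 from the reduced IR above, and with the current dispatch shape it appears to map to the Y workgroup count; any frames x height x width product at or below 65535, e.g. 21 x 45 x 64 = 60480, stays within the limit.)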