[rocm] Wan2.1 layer norm dispatch causes runtime semaphore abort

Open · monorimet opened this issue 6 months ago · 2 comments

What happened?

Dispatch (benchmark) IR (wan_dispatch_7.mlir):

module {
  util.global private @__device_0 = #hal.device.target<"hip", [#hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, ukernels = "none"}>]> : !hal.device
  hal.executable private @forward_t2v_bs1$async_dispatch_7 {
    hal.executable.variant public @rocm_hsaco_fb target(<"rocm", "rocm-hsaco-fb", {abi = "hip", iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, iree_codegen.default_tuning_spec = #rocm.builtin.tuning_module<"iree_default_tuning_spec_gfx942.mlir">, ukernels = "none"}>) {
      hal.executable.export public @forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32 ordinal(0) layout(#hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) count(%arg0: !hal.device) -> (index, index, index) {
        %x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
        hal.return %x, %y, %z : index, index, index
      }
      builtin.module {
        func.func @forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 1, 1] subgroup_size = 64>} {
          %cst = arith.constant 5.120000e+03 : f32
          %cst_0 = arith.constant 9.99999997E-7 : f32
          %c1548590400 = arith.constant 1548590400 : index
          %c0 = arith.constant 0 : index
          %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c1548590400) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600x5120xf32>>
          %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600xf32>>
          %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags(Indirect) {iree_gpu.use_rocdl_buffer_instructions} : !iree_tensor_ext.dispatch.tensor<writeonly:tensor<75600x5120xf32>>
          %3 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0, 0], sizes = [75600, 5120], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600x5120xf32>> -> tensor<75600x5120xf32>
          %4 = iree_tensor_ext.dispatch.tensor.load %1, offsets = [0], sizes = [75600], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<75600xf32>> -> tensor<75600xf32>
          %5 = tensor.empty() : tensor<75600x5120xf32>
          %6 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%3, %4 : tensor<75600x5120xf32>, tensor<75600xf32>) outs(%5 : tensor<75600x5120xf32>) attrs =  {lowering_config = #iree_gpu.lowering_config<{thread = [1, 4], workgroup = [1, 256]}>} {
          ^bb0(%in: f32, %in_1: f32, %out: f32):
            %7 = arith.divf %in_1, %cst : f32
            %8 = arith.addf %7, %cst_0 : f32
            %9 = math.rsqrt %8 : f32
            %10 = arith.mulf %in, %9 : f32
            linalg.yield %10 : f32
          } -> tensor<75600x5120xf32>
          iree_tensor_ext.dispatch.tensor.store %6, %2, offsets = [0, 0], sizes = [75600, 5120], strides = [1, 1] : tensor<75600x5120xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<75600x5120xf32>>
          return
        }
      }
    }
  }
  util.global private mutable @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer
  util.initializer {
    %device, %queue_affinity = hal.device.resolve on(#hal.device.affinity<@__device_0>) : !hal.device, i64
    %allocator = hal.device.allocator<%device : !hal.device> : !hal.allocator
    %memory_type = hal.memory_type<"DeviceVisible|DeviceLocal"> : i32
    %buffer_usage = hal.buffer_usage<"TransferSource|TransferTarget|Transfer|DispatchStorageRead|DispatchStorageWrite|DispatchStorage"> : i32
    %c4645166592 = arith.constant 4645166592 : index
    %buffer = hal.allocator.allocate<%allocator : !hal.allocator> affinity(%queue_affinity) type(%memory_type) usage(%buffer_usage) : !hal.buffer{%c4645166592}
    util.global.store %buffer, @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer
    util.return
  }
  util.func public @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32(%arg0: i32) attributes {iree.abi.stub, iree.reflection = {iree.benchmark = "dispatch"}} {
    %0 = arith.index_cast %arg0 : i32 to index
    %device, %queue_affinity = hal.device.resolve on(#hal.device.affinity<@__device_0>) : !hal.device, i64
    %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot|AllowInlineExecution") categories(Dispatch) affinity(%queue_affinity) : !hal.command_buffer
    %forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer = util.global.load @forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer
    %c0 = arith.constant 0 : index
    %c3096878400 = arith.constant 3096878400 : index
    %c3096878592 = arith.constant 3096878592 : index
    %c1548288000 = arith.constant 1548288000 : index
    %workgroup_x, %workgroup_y, %workgroup_z = hal.executable.calculate_workgroups device(%device : !hal.device) target(@forward_t2v_bs1$async_dispatch_7::@rocm_hsaco_fb::@forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32) : index, index, index
    %exe = hal.executable.lookup device(%device : !hal.device) executable(@forward_t2v_bs1$async_dispatch_7) : !hal.executable
    %ordinal = hal.executable.export.ordinal target(@forward_t2v_bs1$async_dispatch_7::@rocm_hsaco_fb::@forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32) : index
    %c1 = arith.constant 1 : index
    scf.for %arg1 = %c0 to %0 step %c1 {
      hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%exe : !hal.executable)[%ordinal] workgroups([%workgroup_x, %workgroup_y, %workgroup_z]) bindings([
        (%forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer)[%c0, %c3096878400], 
        (%forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32_buffer : !hal.buffer)[%c3096878592, %c1548288000]
      ]) flags("None")
      hal.command_buffer.execution_barrier<%cmd : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None")
    }
    hal.command_buffer.finalize<%cmd : !hal.command_buffer>
    %1 = util.null : !hal.fence
    %fence = hal.fence.create device(%device : !hal.device) flags("None") : !hal.fence
    hal.device.queue.execute<%device : !hal.device> affinity(%queue_affinity) wait(%1) signal(%fence) commands(%cmd) flags("None")
    %c-1_i32 = arith.constant -1 : i32
    %status = hal.fence.await until([%fence]) timeout_millis(%c-1_i32) flags("None") : i32
    util.status.check_ok %status, "failed to wait on timepoint"
    util.return
  }
}

Compile command:

iree-compile --iree-hal-target-backends=rocm --iree-hip-target=gfx942 --iree-execution-model=async-external  wan_dispatch_7.mlir  -o wan_dispatch_7_gfx942.vmfb

Runtime command:

iree-benchmark-module --module=wan_dispatch_7_gfx942.vmfb --device=hip

Runtime output:

c/runtime/src/iree/hal/drivers/hip/event_semaphore.c:786: ABORTED; the semaphore was aborted; while invoking native function hal.fence.await; while calling import; 
[ 0] bytecode module.forward_t2v_bs1$async_dispatch_7_rocm_hsaco_fb_forward_t2v_bs1$async_dispatch_7_elementwise_75600x5120_f32:386

Steps to reproduce your issue

  1. Create wan_dispatch_7.mlir and paste in the IR provided at the top of this issue.
  2. Build the latest top of IREE main (iree-org/iree@e5b65b533fec33c94e5cc7a5079b1c51131d46a7 at the time the issue was encountered), or install via pip install --pre iree-base-compiler iree-base-runtime -f https://iree.dev/pip-release-links.html
  3. Run compile command.
  4. Run benchmark command.

What component(s) does this issue relate to?

Runtime, Compiler

Version information

IREE (https://iree.dev): IREE compiler version 3.6.0rc20250619 @ e22bed8c615a5b27673d79f067a29533a6848dda LLVM version 21.0.0git Optimized build

Also reproduced with a debug source build @ 6c6fa9578acb3955f627e1c784e441de9a30ef1c.

Additional context

Full model IR: https://gist.github.com/monorimet/2ba3058b19e1fe6267820cc0b339f35a

Dispatch print-after-all output: https://sharkpublic.blob.core.windows.net/sharkpublic/ean/wan_ln_out.txt

A wild guess is that this is a bad fusion: creating an export for just the layer norm op with identical inputs/params does not encounter the runtime abort. That is, the following MLIR compiles and runs successfully:

module @module {
  func.func @main(%arg0: !torch.vtensor<[1,75600,5120],f32>) -> !torch.vtensor<[1,75600,5120],f32> attributes {torch.assume_strict_symbolic_shapes} {
    %int5120 = torch.constant.int 5120
    %0 = torch.prim.ListConstruct %int5120 : (!torch.int) -> !torch.list<int>
    %none = torch.constant.none
    %none_0 = torch.constant.none
    %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7
    %true = torch.constant.bool true
    %1 = torch.aten.layer_norm %arg0, %0, %none, %none_0, %float9.999990e-07, %true : !torch.vtensor<[1,75600,5120],f32>, !torch.list<int>, !torch.none, !torch.none, !torch.float, !torch.bool -> !torch.vtensor<[1,75600,5120],f32>
    return %1 : !torch.vtensor<[1,75600,5120],f32>
  }
}

monorimet · Jun 20 '25 19:06

I was able to minimize further with the following torch IR, which includes a few operations on the norm input tensor that appear in the model:

module @module {
  func.func @main(%arg0: !torch.vtensor<[1,5120,21,45,80],bf16>) -> !torch.vtensor<[1,75600,5120],bf16> attributes {torch.assume_strict_symbolic_shapes} {
    %int1 = torch.constant.int 1
    %int5120 = torch.constant.int 5120
    %int75600 = torch.constant.int 75600
    %0 = torch.prim.ListConstruct %int1, %int5120, %int75600 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %1 = torch.aten.view %arg0, %0 : !torch.vtensor<[1,5120,21,45,80],bf16>, !torch.list<int> -> !torch.vtensor<[1,5120,75600],bf16>
    %int1_0 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %2 = torch.aten.transpose.int %1, %int1_0, %int2 : !torch.vtensor<[1,5120,75600],bf16>, !torch.int, !torch.int -> !torch.vtensor<[1,75600,5120],bf16>
    %int1_1 = torch.constant.int 1
    %int0 = torch.constant.int 0
    %int5120_2 = torch.constant.int 5120
    %3 = torch.prim.ListConstruct %int1_1, %int0, %int5120_2 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %none = torch.constant.none
    %none_3 = torch.constant.none
    %none_4 = torch.constant.none
    %false = torch.constant.bool false
    %4 = torch.aten.new_zeros %2, %3, %none, %none_3, %none_4, %false : !torch.vtensor<[1,75600,5120],bf16>, !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1,0,5120],bf16>
    %5 = torch.prim.ListConstruct %2, %4 : (!torch.vtensor<[1,75600,5120],bf16>, !torch.vtensor<[1,0,5120],bf16>) -> !torch.list<vtensor>
    %int1_5 = torch.constant.int 1
    %6 = torch.aten.cat %5, %int1_5 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,75600,5120],bf16>
    %7 = torch.prim.ListConstruct %6 : (!torch.vtensor<[1,75600,5120],bf16>) -> !torch.list<vtensor>
    %int0_6 = torch.constant.int 0
    %8 = torch.aten.cat %7, %int0_6 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[1,75600,5120],bf16>
    %int5120_7 = torch.constant.int 5120
    %9 = torch.prim.ListConstruct %int5120_7 : (!torch.int) -> !torch.list<int>
    %none_8 = torch.constant.none
    %none_9 = torch.constant.none
    %float9.999990e-07 = torch.constant.float 9.9999999999999995E-7
    %true = torch.constant.bool true
    %10 = torch.aten.layer_norm %8, %9, %none_8, %none_9, %float9.999990e-07, %true : !torch.vtensor<[1,75600,5120],bf16>, !torch.list<int>, !torch.none, !torch.none, !torch.float, !torch.bool -> !torch.vtensor<[1,75600,5120],bf16>
    return %10 : !torch.vtensor<[1,75600,5120],bf16>
  }
}

Hopefully this leads to a root cause soon.

monorimet · Jun 20 '25 19:06

It seems to be an issue with a transpose on the norm input: the abort also goes away when I increase --iree-opt-level to O2 or higher. Perhaps we are relying on some advanced fusion or transpose propagation at higher optimization levels to handle these transposes correctly?

monorimet · Jun 20 '25 20:06

@AWoloszyn if you have cycles for it, your context on the HIP runtime may help the most here. PTAL

monorimet · Jun 24 '25 17:06

Output with AMD_LOG_LEVEL=3: Azure

monorimet · Jun 24 '25 17:06

gridDimY must be at most std::numeric_limits<uint16_t>::max() == 65535. The kernel launch is failing because this dispatch has a gridDimY of 75600.
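
As a sanity check, here is a back-of-the-envelope computation in plain C, using the workgroup tile sizes from the lowering_config in the dispatch IR ([1, 256]) and assuming the outer dimension maps to grid Y and the inner dimension to grid X (which matches the observed failure). It shows where the 75600 comes from:

#include <stdio.h>

int main(void) {
  // Shape and tile sizes from the dispatch above:
  // linalg.generic over 75600x5120, lowering_config workgroup = [1, 256].
  const long rows = 75600, cols = 5120;
  const long tile_rows = 1, tile_cols = 256;

  // Assumed mapping: inner dim -> grid X, outer dim -> grid Y.
  long grid_x = (cols + tile_cols - 1) / tile_cols;  // 20
  long grid_y = (rows + tile_rows - 1) / tile_rows;  // 75600

  const long kMaxYZ = 65535;  // HIP gridDim.y/z limit (uint16_t max)
  printf("grid = %ld x %ld, Y/Z limit = %ld -> %s\n", grid_x, grid_y, kMaxYZ,
         grid_y > kMaxYZ ? "launch fails" : "ok");
  return 0;
}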

AWoloszyn · Jun 24 '25 17:06

Yep! rocminfo will show you the limits for your particular device, but nowadays the norm is:

      Grid Max Size:           4294967295(0xffffffff)             
      Grid Max Size per Dimension:
        x                        2147483647(0x7fffffff)             
        y                        65535(0xffff)                      
        z                        65535(0xffff)      

The compiler needs to ensure it doesn't produce workgroup counts larger than 65535 in the Y/Z dimensions.
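
One possible way to bias toward X, sketched in C below purely as an illustration (this is not current IREE codegen behavior, and the helper names are made up): fold the whole logical workgroup grid into the X dimension and recover the logical coordinates from the flat id inside the kernel.

// Hypothetical sketch: launch a 1-D grid (X only) sized to the whole
// logical grid, so an oversized count never lands in Y/Z, and recover
// the logical 2-D workgroup id from the flat X id inside the kernel.
typedef struct { long x, y; } grid2d_t;

// Host side: total number of workgroups to launch along X.
static inline long flattened_grid_size(grid2d_t logical) {
  return logical.x * logical.y;
}

// Kernel side: recover the logical 2-D workgroup id from the flat X id
// (flat_id would be blockIdx.x / workgroup id X in HIP terms).
static inline grid2d_t logical_workgroup_id(long flat_id, grid2d_t logical) {
  grid2d_t id = { flat_id % logical.x, flat_id / logical.x };
  return id;
}

For the dispatch above this would give a single X dimension of 75600 * 20 = 1,512,000 workgroups, comfortably under the 2147483647 X limit.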

benvanik · Jun 24 '25 17:06

@benvanik Is there somewhere we could validate this at compile time for AMD backends? It has come up a couple of times.

AWoloszyn · Jun 24 '25 17:06

It's not possible to do this reliably in the compiler (if there are dynamic shapes there's nothing we can do, if the values are derived from runtime queries, etc.). The HIP HAL could check and return an error while recording the HIP command buffer dispatch; that won't work for indirect dispatches, but may be good enough for most uses today. The best "solution" is to change codegen to bias towards putting more work in X (failures will still happen, but there are clearly cases today where X is getting fewer work items than Y/Z). Integer range analysis could help codegen decide when to do that, since we often roughly know the dimensions. I suppose we could also add checks when it's possible for dynamic values to result in larger sizes, but probably only at < O2 or something.
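
For the HAL-side option, a minimal sketch of the kind of check that could run while recording a direct dispatch (the type and function names below are hypothetical, not the actual IREE runtime API; indirect dispatches, whose counts are read from a buffer at execution time, could not be caught this way):

#include <stdint.h>

// Hypothetical record-time check; not the actual IREE runtime API.
// Direct dispatches with known workgroup counts could be rejected early
// like this instead of aborting the semaphore at execution time.
typedef struct {
  uint32_t max_x;   // e.g. 2147483647 on gfx942
  uint32_t max_yz;  // e.g. 65535
} grid_limits_t;

// Returns 0 if the grid is launchable, nonzero if HIP would reject it.
static int validate_workgroup_count(const grid_limits_t* limits,
                                    uint32_t x, uint32_t y, uint32_t z) {
  if (x > limits->max_x) return 1;
  if (y > limits->max_yz || z > limits->max_yz) return 1;
  return 0;
}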

benvanik · Jun 24 '25 17:06

Thank you @AWoloszyn. Reducing the sequence length (in this model that means restricting the height and width of the output) to below 65535 does circumvent the error, though it is obviously not a solution and only corroborates the identified root cause.

monorimet · Jun 24 '25 20:06