
[DT][gfx950] Bring up llama8b fp8 on mi350

Abhishek-Varma opened this issue 3 months ago · 16 comments

  1. Take the input IR from here.
  2. Run the following command with IREE commit 952da4192af:
iree-compile ~/input.mlir \
--iree-hal-target-backends=rocm --iree-hip-target=gfx950 --iree-hal-target-device=hip \
--iree-opt-level=O3 --iree-dispatch-creation-propagate-collapse-across-expands=true \
--iree-codegen-enable-default-tuning-specs=true --iree-hip-enable-tensor-ukernels \
--iree-hal-indirect-command-buffers=true --iree-stream-resource-memory-model=discrete \
--iree-hip-specialize-dispatches --iree-hal-memoization=true --iree-opt-data-tiling=false \
--iree-dispatch-creation-data-tiling --iree-hip-encoding-layout-resolver=data-tiling 

The above fails in GPUDistributePass with the following error:

error: unsupported non-normalized loops

NOTE: The input dispatch was obtained via Llama 8B f8e4m3fn on gfx950 - after disabling a check which prevents fusion of encoding ops for producers with multiple uses.

Steps to resolve:

  1. Replace GPUDistributePass here with GPUDistributeForAllPass.
  2. Add NormalizeLoopBoundsPass (to normalize scf.forall) right above the aforementioned pass.

NOTE: Both resolution steps above are needed; it won't work with just step 2. Otherwise it would lead to:

requires statically sized, normalized forall op
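
For context on what these passes address: a non-normalized scf.forall is one whose lower bounds are not all zero or whose steps are not all one, and normalization rewrites it into a zero-based, unit-step loop while recovering the original induction values in the body. Below is a minimal standalone C++ sketch of that index mapping; it is illustrative only and not IREE code.

#include <cstdint>
#include <iostream>

// Normalized trip count for a loop (lb; ub; step): ceildiv(ub - lb, step).
int64_t normalizedTripCount(int64_t lb, int64_t ub, int64_t step) {
  return (ub - lb + step - 1) / step;  // assumes step > 0 and ub >= lb
}

int main() {
  // Original loop: for (iv = lb; iv < ub; iv += step)
  int64_t lb = 4, ub = 100, step = 8;
  int64_t trips = normalizedTripCount(lb, ub, step);
  // Normalized loop: for (ivNew = 0; ivNew < trips; ++ivNew),
  // with the original induction value recovered as iv = lb + ivNew * step.
  for (int64_t ivNew = 0; ivNew < trips; ++ivNew) {
    int64_t iv = lb + ivNew * step;
    if (ivNew == 0 || ivNew == trips - 1)
      std::cout << "ivNew=" << ivNew << " -> iv=" << iv << "\n";
  }
  return 0;
}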

Now, the above resolution works because the input dispatch goes through the GPUVectorization pipeline instead of the GPUTileAndFuse pipeline. (CC: @Max191)

But ideally we want the above dispatch to go through the GPUTileAndFuse pipeline. I tried hardcoding the pipeline to GPUTileAndFuse to see what happens, and we get the following error:

'vector.scatter' op write affecting operations on shared resources are restricted to lane or thread distributed contexts.

I'm filing this issue to drive the discussion on how this dispatch is expected to be handled; I can then work towards whatever fix we all converge on.

CC: @MaheshRavishankar @hanhanW @jtuyls @Max191

Abhishek-Varma avatar Sep 26 '25 15:09 Abhishek-Varma

I think the first step is making it go down to TileAndFuse pipeline. Can you share the IR dump for the vector.scatter op issue? It looks like some operations are not fused into the for loop; I wonder which op is the root operation.

hanhanW avatar Sep 26 '25 19:09 hanhanW

cc @nirvedhmeshram and @qedawkins who might have context on deprecating LLVMGPUVectorize in favor of LLVMGPUTileAndFuse.

Max191 avatar Sep 29 '25 16:09 Max191

I think the first step is making it go down to TileAndFuse pipeline. Can you share the IR dump for the vector.scatter op issue? It looks like some operations are not fused into the for loop; I wonder which op is the root operation.

Here is the IR log. iree_linalg_ext.map_scatter is not getting fused into the scf.forall (the loop which has the tiled elementwise operation). This is happening because of this: it returns a clone of the original operation, and no loop nest is generated.

CC: @hanhanW @MaheshRavishankar

Abhishek-Varma avatar Sep 29 '25 17:09 Abhishek-Varma

Yes, I'd expect the map_scatter op to get fused into the forall op. It is wrong if that does not happen.

The IR dump looks okay after the first distribution, because the map_scatter op is fused:

// -----// IR Dump Before TileAndDistributeToWorkgroupsUsingForallOpPass (iree-codegen-tile-and-distribute-to-workgroups-using-forall-op) //----- //
func.func @prefill_bs4$async_dispatch_0_elementwise_broadcast_Dx4096_i64xf16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>} {
  %c1 = arith.constant 1 : index
  %true = arith.constant true
  %c128 = arith.constant 128 : index
  %cst = arith.constant 0.000000e+00 : f16
  %c32_i64 = arith.constant 32 : i64
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
  %3 = arith.extui %1 : i32 to i64
  %4 = arith.shli %3, %c32_i64 : i64
  %5 = arith.extui %0 : i32 to i64
  %6 = arith.ori %5, %4 : i64
  %7 = arith.index_castui %6 : i64 to index
  %8 = arith.index_castui %2 : i32 to index
  %9:2 = util.assume.int 
      %7<umin = 2097152, umax = 8587837440>, 
      %8<umin = 128, umax = 524160, udiv = 128>
    : index, index
  %10 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>>
  %11 = amdgpu.fat_raw_buffer_cast %10 resetOffset : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>> to memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>>
  %12 = iree_tensor_ext.dispatch.workload.ordinal %9#1, 0 : index
  %13 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<?xi64, #hal.descriptor_type<storage_buffer>>{%12}
  %14 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>>{%12}
  %15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%12]
  %16 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%9#0) flags(Indirect) : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%15}
  %17 = iree_codegen.load_from_buffer %11 : memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<128256x4096xf16>
  %18 = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%12]
  %19 = tensor.empty(%18) : tensor<?x128x4096xf16>
  %expand_shape = memref.expand_shape %14 [[0, 1], [2]] output_shape [%18, 128, 4096] : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  %expand_shape_0 = memref.expand_shape %13 [[0, 1]] output_shape [%18, 128] : memref<?xi64, #hal.descriptor_type<storage_buffer>> into memref<?x128xi64, #hal.descriptor_type<storage_buffer>>
  %20 = iree_codegen.load_from_buffer %expand_shape_0 : memref<?x128xi64, #hal.descriptor_type<storage_buffer>> -> tensor<?x128xi64>
  %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%20 : tensor<?x128xi64>) outs(%19 : tensor<?x128x4096xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128]]>} {
  ^bb0(%in: i64, %out: f16):
    %26 = linalg.index 2 : index
    %27 = arith.index_cast %in : i64 to index
    %extracted = tensor.extract %17[%27, %26] : tensor<128256x4096xf16>
    linalg.yield %extracted : f16
  } -> tensor<?x128x4096xf16>
  %22 = affine.apply affine_map<()[s0] -> ((s0 ceildiv 128) * 128)>()[%12]
  %23 = arith.divsi %22, %c128 : index
  %24 = tensor.empty(%15) : tensor<?x32x4x2x8x4x16x4xf16>
  %25 = iree_linalg_ext.map_scatter %21 into %24 {
  ^bb0(%arg0: index, %arg1: index, %arg2: index):
    %26:2 = affine.delinearize_index %arg2 into (32, 128) : index, index
    %27:4 = affine.delinearize_index %arg1 into (4, 4, 2, 4) : index, index, index, index
    %28:2 = affine.delinearize_index %26#1 into (16, 8) : index, index
    iree_linalg_ext.yield %arg0, %26#0, %27#0, %27#2, %28#1, %27#1, %28#0, %27#3, %true : index, index, index, index, index, index, index, index, i1
  } : tensor<?x128x4096xf16> into tensor<?x32x4x2x8x4x16x4xf16> -> tensor<?x32x4x2x8x4x16x4xf16>
  iree_codegen.store_to_buffer %21, %expand_shape : tensor<?x128x4096xf16> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  iree_codegen.store_to_buffer %25, %16 : tensor<?x32x4x2x8x4x16x4xf16> into memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  scf.forall (%arg0, %arg1) = (%12, 0) to (%22, 4096) step (1, 64) {
    %26 = affine.min affine_map<(d0)[s0] -> (d0 + 1, (s0 ceildiv 128) * 128)>(%arg0)[%12]
    %27 = affine.min affine_map<(d0) -> (4096, d0 + 64)>(%arg1)
    scf.forall (%arg2, %arg3) = (%arg0, %arg1) to (%26, %27) step (1, 1) {
      %28 = affine.min affine_map<(d0, d1)[s0] -> (d0 + 1, d1 + 1, (s0 ceildiv 128) * 128)>(%arg2, %arg0)[%12]
      %29 = affine.min affine_map<(d0, d1) -> (4096, d1 + 64, d0 + 1)>(%arg3, %arg1)
      scf.for %arg4 = %arg2 to %28 step %c1 {
        scf.for %arg5 = %arg3 to %29 step %c1 {
          %30:2 = affine.delinearize_index %arg4 into (%23, 128) : index, index
          %31:2 = affine.delinearize_index %arg5 into (32, 128) : index, index
          %32:4 = affine.delinearize_index %30#1 into (4, 4, 2, 4) : index, index, index, index
          %33:2 = affine.delinearize_index %31#1 into (16, 8) : index, index
          memref.store %cst, %16[%30#0, %31#0, %32#0, %32#2, %33#1, %32#1, %33#0, %32#3] : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        }
      }
    } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
  } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
  return
}

// -----// IR Dump Before ConfigTrackingCanonicalizerPass (iree-codegen-config-tracking-canonicalize) //----- //
func.func @prefill_bs4$async_dispatch_0_elementwise_broadcast_Dx4096_i64xf16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>} {
  %c1 = arith.constant 1 : index
  %true = arith.constant true
  %c128 = arith.constant 128 : index
  %cst = arith.constant 0.000000e+00 : f16
  %c32_i64 = arith.constant 32 : i64
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
  %3 = arith.extui %1 : i32 to i64
  %4 = arith.shli %3, %c32_i64 : i64
  %5 = arith.extui %0 : i32 to i64
  %6 = arith.ori %5, %4 : i64
  %7 = arith.index_castui %6 : i64 to index
  %8 = arith.index_castui %2 : i32 to index
  %9:2 = util.assume.int 
      %7<umin = 2097152, umax = 8587837440>, 
      %8<umin = 128, umax = 524160, udiv = 128>
    : index, index
  %10 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>>
  %11 = amdgpu.fat_raw_buffer_cast %10 resetOffset : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>> to memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>>
  %12 = iree_tensor_ext.dispatch.workload.ordinal %9#1, 0 : index
  %13 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<?xi64, #hal.descriptor_type<storage_buffer>>{%12}
  %14 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>>{%12}
  %15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%12]
  %16 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%9#0) flags(Indirect) : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%15}
  %17 = iree_codegen.load_from_buffer %11 : memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<128256x4096xf16>
  %18 = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%12]
  %19 = tensor.empty(%18) : tensor<?x128x4096xf16>
  %expand_shape = memref.expand_shape %14 [[0, 1], [2]] output_shape [%18, 128, 4096] : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  %expand_shape_0 = memref.expand_shape %13 [[0, 1]] output_shape [%18, 128] : memref<?xi64, #hal.descriptor_type<storage_buffer>> into memref<?x128xi64, #hal.descriptor_type<storage_buffer>>
  %20 = iree_codegen.load_from_buffer %expand_shape_0 : memref<?x128xi64, #hal.descriptor_type<storage_buffer>> -> tensor<?x128xi64>
  %21 = affine.apply affine_map<()[s0] -> ((s0 ceildiv 128) * 128)>()[%12]
  %22 = arith.divsi %21, %c128 : index
  %23 = tensor.empty(%15) : tensor<?x32x4x2x8x4x16x4xf16>
  %24:2 = scf.forall (%arg0, %arg1, %arg2) = (0, 0, 0) to (%18, 128, 4096) step (1, 1, 128) shared_outs(%arg3 = %19, %arg4 = %23) -> (tensor<?x128x4096xf16>, tensor<?x32x4x2x8x4x16x4xf16>) {
    %extracted_slice = tensor.extract_slice %20[%arg0, %arg1] [1, 1] [1, 1] : tensor<?x128xi64> to tensor<1x1xi64>
    %extracted_slice_1 = tensor.extract_slice %arg3[%arg0, %arg1, %arg2] [1, 1, 128] [1, 1, 1] : tensor<?x128x4096xf16> to tensor<1x1x128xf16>
    %25 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice : tensor<1x1xi64>) outs(%extracted_slice_1 : tensor<1x1x128xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128]]>} {
    ^bb0(%in: i64, %out: f16):
      %27 = linalg.index 2 : index
      %28 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg2)[%27]
      %29 = arith.index_cast %in : i64 to index
      %extracted = tensor.extract %17[%29, %28] : tensor<128256x4096xf16>
      linalg.yield %extracted : f16
    } -> tensor<1x1x128xf16>
    %extracted_slice_2 = tensor.extract_slice %arg4[0, 0, 0, 0, 0, 0, 0, 0] [%15, 32, 4, 2, 8, 4, 16, 4] [1, 1, 1, 1, 1, 1, 1, 1] : tensor<?x32x4x2x8x4x16x4xf16> to tensor<?x32x4x2x8x4x16x4xf16>
    %26 = iree_linalg_ext.map_scatter %25 into %extracted_slice_2 {
    ^bb0(%arg5: index, %arg6: index, %arg7: index):
      %27 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg5, %arg0)
      %28 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg6, %arg1)
      %29 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg7, %arg2)
      %30:2 = affine.delinearize_index %29 into (32, 128) : index, index
      %31:4 = affine.delinearize_index %28 into (4, 4, 2, 4) : index, index, index, index
      %32:2 = affine.delinearize_index %30#1 into (16, 8) : index, index
      iree_linalg_ext.yield %27, %30#0, %31#0, %31#2, %32#1, %31#1, %32#0, %31#3, %true : index, index, index, index, index, index, index, index, i1
    } : tensor<1x1x128xf16> into tensor<?x32x4x2x8x4x16x4xf16> -> tensor<?x32x4x2x8x4x16x4xf16>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %25 into %arg3[%arg0, %arg1, %arg2] [1, 1, 128] [1, 1, 1] : tensor<1x1x128xf16> into tensor<?x128x4096xf16>
      tensor.parallel_insert_slice %26 into %arg4[0, 0, 0, 0, 0, 0, 0, 0] [%15, 32, 4, 2, 8, 4, 16, 4] [1, 1, 1, 1, 1, 1, 1, 1] : tensor<?x32x4x2x8x4x16x4xf16> into tensor<?x32x4x2x8x4x16x4xf16>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %24#0, %expand_shape : tensor<?x128x4096xf16> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  iree_codegen.store_to_buffer %24#1, %16 : tensor<?x32x4x2x8x4x16x4xf16> into memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  scf.forall (%arg0, %arg1) = (%12, 0) to (%21, 4096) step (1, 64) {
    %25 = affine.min affine_map<(d0)[s0] -> (d0 + 1, (s0 ceildiv 128) * 128)>(%arg0)[%12]
    %26 = affine.min affine_map<(d0) -> (4096, d0 + 64)>(%arg1)
    scf.forall (%arg2, %arg3) = (%arg0, %arg1) to (%25, %26) step (1, 1) {
      %27 = affine.min affine_map<(d0, d1)[s0] -> (d0 + 1, d1 + 1, (s0 ceildiv 128) * 128)>(%arg2, %arg0)[%12]
      %28 = affine.min affine_map<(d0, d1) -> (4096, d1 + 64, d0 + 1)>(%arg3, %arg1)
      scf.for %arg4 = %arg2 to %27 step %c1 {
        scf.for %arg5 = %arg3 to %28 step %c1 {
          %29:2 = affine.delinearize_index %arg4 into (%22, 128) : index, index
          %30:2 = affine.delinearize_index %arg5 into (32, 128) : index, index
          %31:4 = affine.delinearize_index %29#1 into (4, 4, 2, 4) : index, index, index, index
          %32:2 = affine.delinearize_index %30#1 into (16, 8) : index, index
          memref.store %cst, %16[%29#0, %30#0, %31#0, %31#2, %32#1, %31#1, %32#0, %31#3] : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        }
      }
    } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
  } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
  return
}

I think the issue happens in thread distribution. The op is not fused, like @Abhishek-Varma mentioned. I don't have the full context at the moment. The tile sizes are 1x1x1 in thread distribution, but the check is verifying whether they are zero or not. @Abhishek-Varma do you know why, and how to fix the issue?

(Minor note: @Abhishek-Varma please attach the repro next time, just in case others pick it up without getting back to you. It saves communication time.)

// -----// IR Dump Before GPUGreedilyDistributeToThreadsPass (iree-codegen-gpu-greedily-distribute-to-threads) //----- //
func.func @prefill_bs4$async_dispatch_0_elementwise_broadcast_Dx4096_i64xf16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>} {
  %c1 = arith.constant 1 : index
  %true = arith.constant true
  %c128 = arith.constant 128 : index
  %cst = arith.constant 0.000000e+00 : f16
  %c32_i64 = arith.constant 32 : i64
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
  %3 = arith.extui %1 : i32 to i64
  %4 = arith.shli %3, %c32_i64 : i64
  %5 = arith.extui %0 : i32 to i64
  %6 = arith.ori %5, %4 : i64
  %7 = arith.index_castui %6 : i64 to index
  %8 = arith.index_castui %2 : i32 to index
  %9:2 = util.assume.int 
      %7<umin = 2097152, umax = 8587837440>, 
      %8<umin = 128, umax = 524160, udiv = 128>
    : index, index
  %10 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>>
  %11 = amdgpu.fat_raw_buffer_cast %10 resetOffset : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>> to memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>>
  %12 = iree_tensor_ext.dispatch.workload.ordinal %9#1, 0 : index
  %13 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<?xi64, #hal.descriptor_type<storage_buffer>>{%12}
  %14 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>>{%12}
  %15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%12]
  %16 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%9#0) flags(Indirect) : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%15}
  %17 = iree_codegen.load_from_buffer %11 : memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<128256x4096xf16>
  %18 = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%12]
  %19 = tensor.empty(%18) : tensor<?x128x4096xf16>
  %expand_shape = memref.expand_shape %14 [[0, 1], [2]] output_shape [%18, 128, 4096] : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  %expand_shape_0 = memref.expand_shape %13 [[0, 1]] output_shape [%18, 128] : memref<?xi64, #hal.descriptor_type<storage_buffer>> into memref<?x128xi64, #hal.descriptor_type<storage_buffer>>
  %20 = iree_codegen.load_from_buffer %expand_shape_0 : memref<?x128xi64, #hal.descriptor_type<storage_buffer>> -> tensor<?x128xi64>
  %21 = affine.apply affine_map<()[s0] -> ((s0 ceildiv 128) * 128)>()[%12]
  %22 = arith.divsi %21, %c128 : index
  %23 = tensor.empty(%15) : tensor<?x32x4x2x8x4x16x4xf16>
  %24:2 = scf.forall (%arg0, %arg1, %arg2) in (%18, 128, 32) shared_outs(%arg3 = %19, %arg4 = %23) -> (tensor<?x128x4096xf16>, tensor<?x32x4x2x8x4x16x4xf16>) {
    %26 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg2)
    %extracted_slice = tensor.extract_slice %arg3[%arg0, %arg1, %26] [1, 1, 128] [1, 1, 1] : tensor<?x128x4096xf16> to tensor<1x1x128xf16>
    %extracted_slice_1 = tensor.extract_slice %20[%arg0, %arg1] [1, 1] [1, 1] : tensor<?x128xi64> to tensor<1x1xi64>
    %27 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1 : tensor<1x1xi64>) outs(%extracted_slice : tensor<1x1x128xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128]]>} {
    ^bb0(%in: i64, %out: f16):
      %29 = linalg.index 2 : index
      %30 = affine.apply affine_map<(d0)[s0] -> (d0 * 128 + s0)>(%arg2)[%29]
      %31 = arith.index_cast %in : i64 to index
      %extracted = tensor.extract %17[%31, %30] : tensor<128256x4096xf16>
      linalg.yield %extracted : f16
    } -> tensor<1x1x128xf16>
    %28 = iree_linalg_ext.map_scatter %27 into %arg4 {
    ^bb0(%arg5: index, %arg6: index, %arg7: index):
      %29 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg5, %arg0)
      %30 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg6, %arg1)
      %31 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 128)>(%arg7, %arg2)
      %32:2 = affine.delinearize_index %31 into (32, 128) : index, index
      %33:4 = affine.delinearize_index %30 into (4, 4, 2, 4) : index, index, index, index
      %34:2 = affine.delinearize_index %32#1 into (16, 8) : index, index
      iree_linalg_ext.yield %29, %32#0, %33#0, %33#2, %34#1, %33#1, %34#0, %33#3, %true : index, index, index, index, index, index, index, index, i1
    } : tensor<1x1x128xf16> into tensor<?x32x4x2x8x4x16x4xf16> -> tensor<?x32x4x2x8x4x16x4xf16>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %27 into %arg3[%arg0, %arg1, %26] [1, 1, 128] [1, 1, 1] : tensor<1x1x128xf16> into tensor<?x128x4096xf16>
      tensor.parallel_insert_slice %28 into %arg4[0, 0, 0, 0, 0, 0, 0, 0] [%15, 32, 4, 2, 8, 4, 16, 4] [1, 1, 1, 1, 1, 1, 1, 1] : tensor<?x32x4x2x8x4x16x4xf16> into tensor<?x32x4x2x8x4x16x4xf16>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %24#0, %expand_shape : tensor<?x128x4096xf16> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  iree_codegen.store_to_buffer %24#1, %16 : tensor<?x32x4x2x8x4x16x4xf16> into memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %25 = affine.apply affine_map<()[s0] -> (-s0 + (s0 ceildiv 128) * 128)>()[%12]
  scf.forall (%arg0, %arg1) in (%25, 64) {
    %26 = affine.min affine_map<(d0)[s0] -> (d0 + s0 + 1, (s0 ceildiv 128) * 128)>(%arg0)[%12]
    %27 = affine.min affine_map<(d0) -> (4096, d0 * 64 + 64)>(%arg1)
    %28 = affine.apply affine_map<(d0, d1)[s0] -> (d0 - d1 - s0)>(%26, %arg0)[%12]
    %29 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 64)>(%27, %arg1)
    scf.forall (%arg2, %arg3) in (%28, %29) {
      %30 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 64)>(%arg3, %arg1)
      %31 = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 + s0)>(%arg2, %arg0)[%12]
      %32 = affine.min affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 + 1, d1 + s0 + 1, (s0 ceildiv 128) * 128)>(%arg2, %arg0)[%12]
      %33 = affine.min affine_map<(d0, d1) -> (4096, d1 * 64 + 64, d0 + d1 * 64 + 1)>(%arg3, %arg1)
      scf.for %arg4 = %31 to %32 step %c1 {
        %34:2 = affine.delinearize_index %arg4 into (%22, 128) : index, index
        %35:4 = affine.delinearize_index %34#1 into (4, 4, 2, 4) : index, index, index, index
        scf.for %arg5 = %30 to %33 step %c1 {
          %36:2 = affine.delinearize_index %arg5 into (32, 128) : index, index
          %37:2 = affine.delinearize_index %36#1 into (16, 8) : index, index
          memref.store %cst, %16[%34#0, %36#0, %35#0, %35#2, %37#1, %35#1, %37#0, %35#3] : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        }
      }
    } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
  } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
  return
}

// -----// IR Dump Before TileLargeTensorsPass (iree-codegen-tile-large-tensors) //----- //
func.func @prefill_bs4$async_dispatch_0_elementwise_broadcast_Dx4096_i64xf16() attributes {translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [128, 1, 1] subgroup_size = 64>} {
  %c1 = arith.constant 1 : index
  %true = arith.constant true
  %c128 = arith.constant 128 : index
  %cst = arith.constant 0.000000e+00 : f16
  %c32_i64 = arith.constant 32 : i64
  %c0 = arith.constant 0 : index
  %0 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(0) : i32
  %1 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(1) : i32
  %2 = hal.interface.constant.load layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) ordinal(2) : i32
  %3 = arith.extui %1 : i32 to i64
  %4 = arith.shli %3, %c32_i64 : i64
  %5 = arith.extui %0 : i32 to i64
  %6 = arith.ori %5, %4 : i64
  %7 = arith.index_castui %6 : i64 to index
  %8 = arith.index_castui %2 : i32 to index
  %9:2 = util.assume.int 
      %7<umin = 2097152, umax = 8587837440>, 
      %8<umin = 128, umax = 524160, udiv = 128>
    : index, index
  %10 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>>
  %11 = amdgpu.fat_raw_buffer_cast %10 resetOffset : memref<128256x4096xf16, #hal.descriptor_type<storage_buffer>> to memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>>
  %12 = iree_tensor_ext.dispatch.workload.ordinal %9#1, 0 : index
  %13 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : memref<?xi64, #hal.descriptor_type<storage_buffer>>{%12}
  %14 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>>{%12}
  %15 = affine.apply affine_map<()[s0] -> (s0 ceildiv 128)>()[%12]
  %16 = hal.interface.binding.subspan layout(<constants = 3, bindings = [#hal.pipeline.binding<storage_buffer, ReadOnly>, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%9#0) flags(Indirect) : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>{%15}
  %17 = iree_codegen.load_from_buffer %11 : memref<128256x4096xf16, #amdgpu.address_space<fat_raw_buffer>> -> tensor<128256x4096xf16>
  %18 = affine.apply affine_map<()[s0] -> (s0 floordiv 128)>()[%12]
  %19 = tensor.empty(%18) : tensor<?x128x4096xf16>
  %expand_shape = memref.expand_shape %14 [[0, 1], [2]] output_shape [%18, 128, 4096] : memref<?x4096xf16, #hal.descriptor_type<storage_buffer>> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  %expand_shape_0 = memref.expand_shape %13 [[0, 1]] output_shape [%18, 128] : memref<?xi64, #hal.descriptor_type<storage_buffer>> into memref<?x128xi64, #hal.descriptor_type<storage_buffer>>
  %20 = iree_codegen.load_from_buffer %expand_shape_0 : memref<?x128xi64, #hal.descriptor_type<storage_buffer>> -> tensor<?x128xi64>
  %21 = affine.apply affine_map<()[s0] -> ((s0 ceildiv 128) * 128)>()[%12]
  %22 = arith.divsi %21, %c128 : index
  %23 = tensor.empty(%15) : tensor<?x32x4x2x8x4x16x4xf16>
  %24:2 = scf.forall (%arg0, %arg1, %arg2) in (%18, 128, 32) shared_outs(%arg3 = %19, %arg4 = %23) -> (tensor<?x128x4096xf16>, tensor<?x32x4x2x8x4x16x4xf16>) {
    %26 = affine.apply affine_map<(d0) -> (d0 * 128)>(%arg2)
    %extracted_slice = tensor.extract_slice %arg3[%arg0, %arg1, %26] [1, 1, 128] [1, 1, 1] : tensor<?x128x4096xf16> to tensor<1x1x128xf16>
    %extracted_slice_1 = tensor.extract_slice %20[%arg0, %arg1] [1, 1] [1, 1] : tensor<?x128xi64> to tensor<1x1xi64>
    %27 = scf.forall (%arg5, %arg6, %arg7) in (1, 1, 128) shared_outs(%arg8 = %extracted_slice) -> (tensor<1x1x128xf16>) {
      %extracted_slice_2 = tensor.extract_slice %20[%arg0, %arg1] [1, 1] [1, 1] : tensor<?x128xi64> to tensor<1x1xi64>
      %extracted_slice_3 = tensor.extract_slice %extracted_slice_2[%arg5, %arg6] [1, 1] [1, 1] : tensor<1x1xi64> to tensor<1x1xi64>
      %extracted_slice_4 = tensor.extract_slice %extracted_slice_1[%arg5, %arg6] [1, 1] [1, 1] : tensor<1x1xi64> to tensor<1x1xi64>
      %extracted_slice_5 = tensor.extract_slice %arg3[%arg0, %arg1, %26] [1, 1, 128] [1, 1, 1] : tensor<?x128x4096xf16> to tensor<1x1x128xf16>
      %extracted_slice_6 = tensor.extract_slice %extracted_slice_5[%arg5, %arg6, %arg7] [1, 1, 1] [1, 1, 1] : tensor<1x1x128xf16> to tensor<1x1x1xf16>
      %extracted_slice_7 = tensor.extract_slice %arg8[%arg5, %arg6, %arg7] [1, 1, 1] [1, 1, 1] : tensor<1x1x128xf16> to tensor<1x1x1xf16>
      %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_4 : tensor<1x1xi64>) outs(%extracted_slice_7 : tensor<1x1x1xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128]]>} {
      ^bb0(%in: i64, %out: f16):
        %30 = linalg.index 2 : index
        %31 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%arg7)[%30]
        %32 = affine.apply affine_map<(d0)[s0] -> (d0 * 128 + s0)>(%arg2)[%31]
        %33 = arith.index_cast %in : i64 to index
        %extracted = tensor.extract %17[%33, %32] : tensor<128256x4096xf16>
        linalg.yield %extracted : f16
      } -> tensor<1x1x1xf16>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %29 into %arg8[%arg5, %arg6, %arg7] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf16> into tensor<1x1x128xf16>
      }
    } {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
    %28 = iree_linalg_ext.map_scatter %27 into %arg4 {
    ^bb0(%arg5: index, %arg6: index, %arg7: index):
      %29 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg5, %arg0)
      %30 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg6, %arg1)
      %31 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 128)>(%arg7, %arg2)
      %32:2 = affine.delinearize_index %31 into (32, 128) : index, index
      %33:4 = affine.delinearize_index %30 into (4, 4, 2, 4) : index, index, index, index
      %34:2 = affine.delinearize_index %32#1 into (16, 8) : index, index
      iree_linalg_ext.yield %29, %32#0, %33#0, %33#2, %34#1, %33#1, %34#0, %33#3, %true : index, index, index, index, index, index, index, index, i1
    } : tensor<1x1x128xf16> into tensor<?x32x4x2x8x4x16x4xf16> -> tensor<?x32x4x2x8x4x16x4xf16>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %27 into %arg3[%arg0, %arg1, %26] [1, 1, 128] [1, 1, 1] : tensor<1x1x128xf16> into tensor<?x128x4096xf16>
      tensor.parallel_insert_slice %28 into %arg4[0, 0, 0, 0, 0, 0, 0, 0] [%15, 32, 4, 2, 8, 4, 16, 4] [1, 1, 1, 1, 1, 1, 1, 1] : tensor<?x32x4x2x8x4x16x4xf16> into tensor<?x32x4x2x8x4x16x4xf16>
    }
  } {mapping = [#iree_codegen.workgroup_mapping<z>, #iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}
  iree_codegen.store_to_buffer %24#0, %expand_shape : tensor<?x128x4096xf16> into memref<?x128x4096xf16, #hal.descriptor_type<storage_buffer>>
  iree_codegen.store_to_buffer %24#1, %16 : tensor<?x32x4x2x8x4x16x4xf16> into memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
  %25 = affine.apply affine_map<()[s0] -> (-s0 + (s0 ceildiv 128) * 128)>()[%12]
  scf.forall (%arg0, %arg1) in (%25, 64) {
    %26 = affine.min affine_map<(d0)[s0] -> (d0 + s0 + 1, (s0 ceildiv 128) * 128)>(%arg0)[%12]
    %27 = affine.min affine_map<(d0) -> (4096, d0 * 64 + 64)>(%arg1)
    %28 = affine.apply affine_map<(d0, d1)[s0] -> (d0 - d1 - s0)>(%26, %arg0)[%12]
    %29 = affine.apply affine_map<(d0, d1) -> (d0 - d1 * 64)>(%27, %arg1)
    scf.forall (%arg2, %arg3) in (%28, %29) {
      %30 = affine.apply affine_map<(d0, d1) -> (d0 + d1 * 64)>(%arg3, %arg1)
      %31 = affine.apply affine_map<(d0, d1)[s0] -> (d0 + d1 + s0)>(%arg2, %arg0)[%12]
      %32 = affine.min affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 + 1, d1 + s0 + 1, (s0 ceildiv 128) * 128)>(%arg2, %arg0)[%12]
      %33 = affine.min affine_map<(d0, d1) -> (4096, d1 * 64 + 64, d0 + d1 * 64 + 1)>(%arg3, %arg1)
      scf.for %arg4 = %31 to %32 step %c1 {
        %34:2 = affine.delinearize_index %arg4 into (%22, 128) : index, index
        %35:4 = affine.delinearize_index %34#1 into (4, 4, 2, 4) : index, index, index, index
        scf.for %arg5 = %30 to %33 step %c1 {
          %36:2 = affine.delinearize_index %arg5 into (32, 128) : index, index
          %37:2 = affine.delinearize_index %36#1 into (16, 8) : index, index
          memref.store %cst, %16[%34#0, %36#0, %35#0, %35#2, %37#1, %35#1, %37#0, %35#3] : memref<?x32x4x2x8x4x16x4xf16, strided<[524288, 16384, 4096, 2048, 256, 64, 4, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
        }
      }
    } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
  } {mapping = [#iree_codegen.workgroup_mapping<x>, #iree_codegen.workgroup_mapping<y>]}
  return
}

hanhanW avatar Sep 29 '25 22:09 hanhanW

(Minor note: @Abhishek-Varma please attach the repro next time, just in case others pick it up without getting back to you. It saves communication time.)

My bad. I missed adding repro for the last bits. I'll keep this in mind for next time.

So I was able to triage it further and add a fix.

The root cause is that MapScatterOp is NOT handled when deriving tile sizes: https://github.com/iree-org/iree/blob/6538d9a8a782b8706e7e7634de96743089f4c97b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/DerivedConfigUtils.cpp#L164-L178

It worked for TileAndDistributeToWorkgroupsUsingForallOpPass because there we obtain the workgroup_size from the config and apply it. But in the case of GPUGreedilyDistributeToThreadsPass we obtain each individual op's tile sizes here via DerivedThreadConfigAttr.

Now, to obtain loop bounds for MapScatterOp I can't use getStaticLoopRanges! I needed to understand the op's semantics first, so I looked into DecomposeMapScatterOpPass, which goes MapScatterOp -> GenericOp -> Vectorization. Through that GenericOp I figured out that, roughly speaking, it is the value shape that is being written into the main buffer. Once I used that to obtain the loop bounds, I used the same implementation as ScatterOp to create the tile sizes.
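
To make the approach concrete, below is a simplified, standalone C++ model of the derivation described above: treat the shape of the value written by map_scatter as the loop bounds and distribute the workgroup's threads over the innermost dimensions first. This is only an illustration of the idea; the function name and heuristic are hypothetical and do not reproduce IREE's DerivedConfigUtils implementation.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Given the shape of the value written by the (hypothetical) map_scatter and
// the number of threads in the workgroup, pick per-thread tile sizes by
// distributing threads over the innermost dimensions first, in the same
// spirit as the existing scatter path.
std::vector<int64_t> deriveThreadTileSizes(const std::vector<int64_t> &shape,
                                           int64_t numThreads) {
  std::vector<int64_t> tileSizes(shape.size(), 1);
  int64_t remainingThreads = numThreads;
  for (int64_t i = static_cast<int64_t>(shape.size()) - 1; i >= 0; --i) {
    int64_t threadsForDim = std::min(remainingThreads, shape[i]);
    // Each thread covers ceildiv(dim, threadsForDim) elements of this dim.
    tileSizes[i] = (shape[i] + threadsForDim - 1) / threadsForDim;
    remainingThreads = std::max<int64_t>(1, remainingThreads / threadsForDim);
  }
  return tileSizes;
}

int main() {
  // Workgroup-level slice from the dump: tensor<1x1x128xf16>, workgroup_size = [128, 1, 1].
  std::vector<int64_t> valueShape = {1, 1, 128};
  for (int64_t size : deriveThreadTileSizes(valueShape, /*numThreads=*/128))
    std::cout << size << " ";  // prints: 1 1 1
  std::cout << "\n";
  return 0;
}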

Here is the IR with the fix - can someone confirm it looks okay? I was able to compile the dispatch e2e without any issue now.

I'll clean up the code and try to raise a PR with the above fix.

CC: @hanhanW @MaheshRavishankar @jtuyls @Max191

Abhishek-Varma avatar Sep 30 '25 12:09 Abhishek-Varma

Nice work @Abhishek-Varma! The DerivedThreadConfigAttr implementation for map_scatter is useful to have, so I think you can go ahead and send the PR for review.

There is still a slight issue with the final IR here, though. After the change, there are 2 scf.forall thread loops next to each other:

      %27 = scf.forall (%arg5, %arg6, %arg7) in (1, 1, 128) shared_outs(%arg8 = %extracted_slice) -> (tensor<1x1x128xf16>) {
        %extracted_slice_2 = tensor.extract_slice %arg8[0, 0, %arg7] [1, 1, 1] [1, 1, 1] : tensor<1x1x128xf16> to tensor<1x1x1xf16>
        %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1 : tensor<1x1xi64>) outs(%extracted_slice_2 : tensor<1x1x1xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128]]>} {
        ^bb0(%in: i64, %out: f16):
          %30 = affine.apply affine_map<(d0)[s0] -> (d0 * 128 + s0)>(%arg2)[%arg7]
          %31 = arith.index_cast %in : i64 to index
          %extracted = tensor.extract %17[%31, %30] : tensor<128256x4096xf16>
          linalg.yield %extracted : f16
        } -> tensor<1x1x1xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %29 into %arg8[0, 0, %arg7] [1, 1, 1] [1, 1, 1] : tensor<1x1x1xf16> into tensor<1x1x128xf16>
        }
      } {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
      %28 = scf.forall (%arg5, %arg6, %arg7) in (1, 1, 128) shared_outs(%arg8 = %arg4) -> (tensor<?x32x4x2x8x4x16x4xf16>) {
        %extracted_slice_2 = tensor.extract_slice %extracted_slice[0, 0, %arg7] [1, 1, 1] [1, 1, 1] : tensor<1x1x128xf16> to tensor<1x1x1xf16>
        %29 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%extracted_slice_1 : tensor<1x1xi64>) outs(%extracted_slice_2 : tensor<1x1x1xf16>) attrs =  {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 1, 128]]>} {
        ^bb0(%in: i64, %out: f16):
          %31 = affine.apply affine_map<(d0)[s0] -> (d0 * 128 + s0)>(%arg2)[%arg7]
          %32 = arith.index_cast %in : i64 to index
          %extracted = tensor.extract %17[%32, %31] : tensor<128256x4096xf16>
          linalg.yield %extracted : f16
        } -> tensor<1x1x1xf16>
        %30 = iree_linalg_ext.map_scatter %29 into %arg8 {
        ^bb0(%arg9: index, %arg10: index, %arg11: index):
          %31 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg9, %arg0)
          %32 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%arg10, %arg1)
          %33 = affine.apply affine_map<(d0, d1, d2) -> (d0 * 128 + d1 + d2)>(%arg2, %arg11, %arg7)
          %34:2 = affine.delinearize_index %33 into (32, 128) : index, index
          %35:4 = affine.delinearize_index %32 into (4, 4, 2, 4) : index, index, index, index
          %36:2 = affine.delinearize_index %34#1 into (16, 8) : index, index
          iree_linalg_ext.yield %31, %34#0, %35#0, %35#2, %36#1, %35#1, %36#0, %35#3, %true : index, index, index, index, index, index, index, index, i1
        } : tensor<1x1x1xf16> into tensor<?x32x4x2x8x4x16x4xf16> -> tensor<?x32x4x2x8x4x16x4xf16>
        %dim = tensor.dim %arg4, %c0 : tensor<?x32x4x2x8x4x16x4xf16>
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %30 into %arg8[0, 0, 0, 0, 0, 0, 0, 0] [%dim, 32, 4, 2, 8, 4, 16, 4] [1, 1, 1, 1, 1, 1, 1, 1] : tensor<?x32x4x2x8x4x16x4xf16> into tensor<?x32x4x2x8x4x16x4xf16>
        }
      } {mapping = [#gpu.thread<linear_dim_2>, #gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}

This likely happened because we have a producer/consumer graph that looks like this:

          linalg.generic
              /     \
     map_scatter   store
          |
        store

The GPUGreedilyDistributeToThreadsPass will work bottom up, starting from the end of the block, distributing each operation one by one, and then fusing any producers into its loop. This means that the following happened:

  1. Greedily distribute map_scatter, creating an scf.forall
  2. Fuse the map_scatter producer (linalg.generic) into the scf.forall loop
    • The linalg.generic is not erased, because it still has another user (the store)
    • Now, we have the scf.forall with the map_scatter, and a copy of linalg.generic, and then the original linalg.generic outside the loop
  3. The original linalg.generic outside the loop is greedily distributed to threads, creating another scf.forall.

The reason this is problematic is that we now have 2 copies of the linalg.generic op, so it could potentially be reading the inputs twice.
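
A toy, standalone C++ model of this bottom-up behavior (illustrative names, not the actual pass) shows how the extra user keeps the original linalg.generic alive and results in two copies:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy model: walk the block bottom-up, wrap each not-yet-distributed tileable
// op in its own "loop", and fuse producers by *cloning* them into the loop.
// The clone does not erase the original producer when it still has another
// user, which is how two copies of the elementwise op appear.
int main() {
  // op -> producers it reads from (names mirror the dispatch, hypothetically).
  std::map<std::string, std::vector<std::string>> producers = {
      {"map_scatter", {"linalg.generic"}},
      {"store_generic", {"linalg.generic"}},   // second user of the generic
      {"store_map_scatter", {"map_scatter"}},
  };
  std::vector<std::string> tileableOps = {"linalg.generic", "map_scatter"};
  std::map<std::string, int> copies = {{"linalg.generic", 1}, {"map_scatter", 1}};

  // Bottom-up over tileable ops: map_scatter first, then linalg.generic.
  for (auto it = tileableOps.rbegin(); it != tileableOps.rend(); ++it) {
    std::cout << "distribute " << *it << " into its own scf.forall\n";
    for (const std::string &producer : producers[*it]) {
      // Fusion clones the producer into the loop; the original survives
      // because store_generic still uses it.
      ++copies[producer];
      std::cout << "  clone producer " << producer << " into the loop\n";
    }
  }
  std::cout << "copies of linalg.generic after distribution: "
            << copies["linalg.generic"] << "\n";  // 2 -> inputs may be read twice
  return 0;
}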

@Abhishek-Varma you should go ahead and send a PR for what you have that fixes the compilation failure, since that is useful already. But let's keep this issue open, since we will want to fix the issue of the multiple copies of the linalg.generic. It could be tricky, so it might need some further discussion.

Max191 avatar Sep 30 '25 14:09 Max191

So the reason the input dispatch doesn't go through the GPUTileAndFuse pipeline is the following check, which bails out when trying to assign the GPUTileAndFuse pipeline: https://github.com/iree-org/iree/blob/3351570c5c26d81ed130cff2cbfd8d68cec6bd20/compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp#L1062-L1066

The API hasIndexSemantics returns true if it finds any linalg.index op in the body of the linalg op. I understand the intention was to guard against gather-like operations, but the input dispatch is an elementwise broadcast.

@qedawkins @hanhanW @Max191 - wanted to know your thoughts on this. Perhaps we should either have an elementwise/elementwise-broadcast-specific check OR completely get rid of this check, as the code comment suggests it was added simply to tile+fuse - please let me know how to go about this.

I tried removing the check - the dispatch then goes through GPUTileAndFuse as expected, without having to hardcode GPUTileAndFuse instead of GPUVectorization (as mentioned earlier in this issue). I even tried running llama 8B on gfx950 after removing the check - it compiled successfully and ran.

CC: @MaheshRavishankar @hanhanW @jtuyls @Max191 @qedawkins

Abhishek-Varma avatar Oct 03 '25 20:10 Abhishek-Varma

cc @nirvedhmeshram. I think he is removing that condition here: https://github.com/iree-org/iree/pull/22195. That landed and got subsequently reverted, but I think Nirvedh is working on relanding it.

MaheshRavishankar avatar Oct 06 '25 17:10 MaheshRavishankar

My PR (here is the new one: https://github.com/iree-org/iree/pull/22223) only relaxes the condition for linalg.index ops in the body, as it checks LinalgExt::isGatherlikeOp(linalgOp). But the dispatch in this issue is actually gather-like (since it also has a tensor.extract in the body):

    %21 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%20 : tensor<?x128xi64>) outs(%19 : tensor<?x128x4096xf16>) {
    ^bb0(%in: i64, %out: f16):
      %26 = linalg.index 2 : index
      %27 = arith.index_cast %in : i64 to index
      %extracted = tensor.extract %17[%27, %26] : tensor<128256x4096xf16>
      linalg.yield %extracted : f16
    } -> tensor<?x128x4096xf16>

So it would still fail to go to TileAndFuse. I think if we want to support gather-like dispatches in TileAndFuse, that would be a separate PR after e2e support for them is added.
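
To illustrate the distinction being discussed, here is a standalone C++ sketch: an index-semantics-style check only looks for linalg.index in the body, while a gather-like check also requires a data-dependent tensor.extract. The op model and helper names are illustrative, not the MLIR/IREE API.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Toy body of a linalg.generic, represented as a list of op names.
using Body = std::vector<std::string>;

// Roughly what an index-semantics check looks for: any linalg.index in the body.
bool hasIndexSemantics(const Body &body) {
  return std::find(body.begin(), body.end(), "linalg.index") != body.end();
}

// Roughly what a gather-like check cares about: an extract from another
// tensor, i.e. a data-dependent load, not just index arithmetic.
bool isGatherLike(const Body &body) {
  return std::find(body.begin(), body.end(), "tensor.extract") != body.end();
}

int main() {
  // Body of the dispatch in this issue: uses linalg.index *and* tensor.extract,
  // so it is gather-like rather than a plain elementwise broadcast.
  Body dispatchBody = {"linalg.index", "arith.index_cast", "tensor.extract",
                       "linalg.yield"};
  // A broadcast that only does index arithmetic would trip the index-semantics
  // check but not the gather-like check.
  Body broadcastOnly = {"linalg.index", "arith.addi", "linalg.yield"};

  std::cout << std::boolalpha
            << "dispatch:  hasIndexSemantics=" << hasIndexSemantics(dispatchBody)
            << " gatherLike=" << isGatherLike(dispatchBody) << "\n"
            << "broadcast: hasIndexSemantics=" << hasIndexSemantics(broadcastOnly)
            << " gatherLike=" << isGatherLike(broadcastOnly) << "\n";
  return 0;
}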

nirvedhmeshram avatar Oct 06 '25 21:10 nirvedhmeshram

I'm using this issue as the main issue for tracking "bring up llama8b fp8 on mi350".

@Abhishek-Varma can you help generate metrics similar to this, so we can see the remaining work for the model?

To me, the bar for the bring-up is having it compile and run, and generating the metrics. We can start the work stream for performance burn-down after the bring-up.

hanhanW avatar Oct 14 '25 16:10 hanhanW

Listing here the perf breakdown for non-data-tiled vs data-tiled compilation of llama 8b on gfx950. The IR has been obtained from here.

| | No Data Tiling (with uk) | With Data Tiling (no uk) | With Data Tiling (with uk) |
| --- | --- | --- | --- |
| Benchmark latency | 256 ms | 274 ms | 331 ms |
| Golden Dispatch Count | 1486 | 1936 | 1712 |

| Dispatch Name | No Data Tiling (ms) | Data Tiling (ms) |
| --- | --- | --- |
| dispatch_644_matmul_like__Dx128256x409_f32 | 81.26 | 77.78 |
| dispatch_19_matmul_like__Dx14336x4096_f8e4m3fnxf8e4m3fnxf32 | 23.24 | 38.44 |
| dispatch_15_attention_4x8x4xDx128xf8e4m3fn_generic | 36.8 | 36.42 |
| dispatch_20_matmul_like__Dx14336x4096_f8e4m3fnxf8e4m3fnxf32 | 29.32 | 31.8 |
| dispatch_21_matmul_like__Dx4096x14336_f8e4m3fnxf8e4m3fnxf32 | 19.36 | 21.72 |
| dispatch_645_reduction_Dx128256_f32xf32xi64 | 8.81 | 8.71 |
| dispatch_1_matmul_like_Dx4096x4096_f8e4m3fnxf8e4m3fnxf32 | 6.85 | 6.9 |
| dispatch_16_matmul_like_Dx4096x4096_f8e4m3fnxf8e4m3fnxf32 | 7.17 | 6.3 |
| dispatch_22_reduction_Dx4096_f32 | -- | 3.91 |

CC: @hanhanW @MaheshRavishankar @jtuyls

Abhishek-Varma avatar Nov 21 '25 13:11 Abhishek-Varma

Thanks @Abhishek-Varma ! This is a good breakdown. Can you also add a column for e2e performance?

Few questions:

  • I remember that there are no additional encoding dispatches, i.e., the number of dispatches is the same between them. Can you confirm?
  • If the numbers of dispatches match, the missing number for dispatch_22_reduction_Dx4096_f32 means we just haven't found it yet, right? They should have the same performance if encodings are not involved. So can you check whether the dispatch has relayout ops or not?
  • We'll need to look at dispatch_19 as a next step. I think you identified that the ukernel is not kicking in for a few models; is that the case? If so, we can use some guidance from @Yu-Zhewen here.

Feel free to correct me if I miss something, thanks again for the breakdown!

hanhanW avatar Nov 21 '25 19:11 hanhanW

Listing here the perf breakdown for non-data-tiled vs data-tiled compilation of llama 8b on gfx950. The IR has been obtained from here.

| | No Data Tiling | With Data Tiling |
| --- | --- | --- |
| Benchmark latency | 256 ms | 274 ms |
| Golden Dispatch Count | 1486 | 1936 |

| Dispatch Name | No Data Tiling (ms) | Data Tiling (ms) |
| --- | --- | --- |
| dispatch_644_matmul_like__Dx128256x409_f32 | 81.26 | 77.78 |
| dispatch_19_matmul_like__Dx14336x4096_f8e4m3fnxf8e4m3fnxf32 | 23.24 | 38.44 |
| dispatch_15_attention_4x8x4xDx128xf8e4m3fn_generic | 36.8 | 36.42 |
| dispatch_20_matmul_like__Dx14336x4096_f8e4m3fnxf8e4m3fnxf32 | 29.32 | 31.8 |
| dispatch_21_matmul_like__Dx4096x14336_f8e4m3fnxf8e4m3fnxf32 | 19.36 | 21.72 |
| dispatch_645_reduction_Dx128256_f32xf32xi64 | 8.81 | 8.71 |
| dispatch_1_matmul_like_Dx4096x4096_f8e4m3fnxf8e4m3fnxf32 | 6.85 | 6.9 |
| dispatch_16_matmul_like_Dx4096x4096_f8e4m3fnxf8e4m3fnxf32 | 7.17 | 6.3 |
| dispatch_22_reduction_Dx4096_f32 | -- | 3.91 |

CC: @hanhanW @MaheshRavishankar @jtuyls

Let's take a look at this together in the data-tiling sync. It looks like we can close the gap if we can make dispatch_19/20/21 as fast as non-data-tiling. (15+2+2 ms ~= 20 ms)
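
For reference, the deltas behind that estimate, computed from the table above:

#include <cstdio>

int main() {
  // Per-dispatch deltas (data tiling minus no data tiling), from the table.
  double d19 = 38.44 - 23.24;  // ~15.2 ms
  double d20 = 31.80 - 29.32;  // ~2.5 ms
  double d21 = 21.72 - 19.36;  // ~2.4 ms
  // Total is ~20 ms, roughly matching the 274 - 256 = 18 ms e2e latency gap.
  std::printf("dispatch_19/20/21 deltas: %.2f %.2f %.2f ms, total %.2f ms\n",
              d19, d20, d21, d19 + d20 + d21);
  return 0;
}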

The golden dispatch count does not match my impression. I did not see additional encoding dispatches last time I checked the dump: https://github.com/iree-org/iree/pull/22444#issuecomment-3463236728

Maybe something is off. How did you get those numbers? Can you prepare a fresh dump like I did? Here is my old script: https://gist.githubusercontent.com/hanhanW/f3011926ac6edd218d15c58d5c4ffa97/raw/920f678fb32d6e8259d4a175f992437103de72e7/compile-8b-fp8.sh

We can check the IR dump after iree-dispatch-creation-convert-encoding-to-flow and see if all the flow.encode ops take weights/globals as operands. If so, we should have identical dispatch counts between DT and non-DT, and something would then be off in the golden dispatch count.

Anyway, this is very good progress; let's burn down the perf issues!

hanhanW avatar Nov 25 '25 06:11 hanhanW

Hi @hanhanW @MaheshRavishankar

@Yu-Zhewen helped triage the issue of ops not being replaced with ukernels for f8e4m3fn.

I've updated the e2e results here: e2e results.

Also adding the perf trace with the ukernel here: [image: perf trace]

Abhishek-Varma avatar Nov 26 '25 13:11 Abhishek-Varma

Ah...so adding ukernels actually made data tiling perform worse (331 ms vs 274 ms)? We might need to check the waveforms to see what’s really happening.

Yu-Zhewen avatar Nov 26 '25 13:11 Yu-Zhewen

Ah...so adding ukernels actually made data tiling perform worse (331 ms vs 274 ms)? We might need to check the waveforms to see what’s really happening.

Let's definitely get a threadtrace to see what's happening, but it's not surprising that the ukernel doesn't work well. The mi350 chips have more shared mem, vgpr space, etc., so we can afford to have larger tile sizes. At a minimum, we probably need to adjust the tile size to better fit the larger chip.
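
As a rough illustration of the tile-size headroom argument, here is a standalone C++ estimate of the largest K for which a 128x128xK f8 matmul tile's staged operands fit in shared memory, for a couple of hypothetical LDS budgets. The budgets and tile shape are placeholder assumptions, not mi350 specifications, and register or scheduling constraints are ignored.

#include <cstdint>
#include <cstdio>

// Bytes of LDS needed for an M x N x K matmul tile that stages the A (M x K)
// and B (K x N) operands in shared memory.
int64_t ldsBytesForTile(int64_t m, int64_t n, int64_t k, int64_t elemBytes) {
  return (m * k + k * n) * elemBytes;
}

int main() {
  const int64_t elemBytes = 1;  // f8e4m3fn
  // Hypothetical LDS budgets; check the actual per-workgroup limit for the target.
  for (int64_t ldsBudget : {64 * 1024, 160 * 1024}) {
    // Grow K for a fixed 128x128 tile until the staged operands no longer fit.
    int64_t k = 32;
    while (ldsBytesForTile(128, 128, k * 2, elemBytes) <= ldsBudget) k *= 2;
    std::printf("LDS budget %lld KiB -> largest fitting 128x128xK tile: K = %lld\n",
                static_cast<long long>(ldsBudget / 1024),
                static_cast<long long>(k));
  }
  return 0;
}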

Max191 avatar Dec 01 '25 15:12 Max191

Closing the issue because we successfully brought up the model. Now the issue is about performance, and let's move the discussion to https://github.com/iree-org/iree/issues/21958

(I moved the last three comments to the new issue, so we won't miss the comments.)

hanhanW avatar Dec 15 '25 07:12 hanhanW