
[CodeGen][SPIRV] Lowering for clustered reduce not implemented

Open • Muzammiluddin-Syed-ECE opened this issue 7 months ago • 1 comment

What happened?

Context

To make effective use of the DPP operations available on AMD GPUs, the PR below changed the implementation of warp reduction to preserve gpu.subgroup_reduce ops instead of immediately lowering them to butterfly shuffles built from gpu.shuffle xor ops. Preserving the subgroup_reduce ops lets later stages of the pipeline lower them to target-specific ops where such ops exist. With this PR, reductions both within a warp and across warps are expressed with subgroup_reduce ops.

https://github.com/iree-org/iree/pull/20468
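For reference, here is a minimal host-side Python sketch of what the butterfly lowering mentioned above computes: log2(n) rounds of "shuffle xor, then add", after which every lane holds the full sum. It assumes a power-of-two subgroup size and addition as the combining op; shuffle_xor and butterfly_reduce are hypothetical helpers that model the semantics, not IREE or MLIR APIs.

```python
# Hypothetical host-side model of a warp reduction lowered to gpu.shuffle xor.
# Each list element stands for one lane's value; this is not IREE/MLIR code.


def shuffle_xor(values, offset):
    # Models `gpu.shuffle xor`: lane i receives the value held by lane i ^ offset.
    return [values[lane ^ offset] for lane in range(len(values))]


def butterfly_reduce(values):
    # Full-subgroup sum in log2(n) shuffle-xor + add steps.
    # After the loop, every lane holds the total.
    n = len(values)
    assert n > 0 and n & (n - 1) == 0, "subgroup size must be a power of two"
    acc = list(values)
    offset = 1
    while offset < n:
        partner = shuffle_xor(acc, offset)
        acc = [a + b for a, b in zip(acc, partner)]
        offset *= 2
    return acc


if __name__ == "__main__":
    lanes = [float(i) for i in range(16)]  # subgroup_size = 16, as in the repro's target env
    print(butterfly_reduce(lanes)[0])  # 120.0
```

Keeping gpu.subgroup_reduce intact instead of emitting this shuffle sequence early is what leaves room for a backend to pick a better instruction (such as DPP on AMD) later.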

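However, SPIRV support for clustered subgroup_reduce (a subgroup_reduce with a cluster size) is incomplete, which makes reduction across multiple warps difficult. In the cross-warp step, each warp's partial result is written to workgroup memory and then reduced again within a single subgroup. Because a workgroup is not guaranteed to contain 64 x 64 threads, that second step usually has fewer values than the subgroup has lanes: in the reproduction below, a workgroup of 128 threads with a subgroup size of 16 produces 8 partial sums, so the second reduce uses cluster(size = 8). Being able to perform subgroup_reduce over fewer than 64 threads is therefore useful.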
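To make the clustered semantics concrete, here is a small Python model of what gpu.subgroup_reduce add ... cluster(size = N) computes, assuming the default cluster stride of 1 and a power-of-two cluster size. The subgroup is split into contiguous clusters of N lanes and each cluster is reduced independently. clustered_reduce is a hypothetical helper for illustration, not an IREE or MLIR API.

```python
# Hypothetical model of `gpu.subgroup_reduce add %x cluster(size = N)` with stride 1.


def clustered_reduce(values, cluster_size):
    # Lanes are split into aligned, contiguous clusters of `cluster_size`.
    # A butterfly that only uses xor offsets < cluster_size flips only the low
    # bits of the lane id, so lanes never exchange data across clusters.
    n = len(values)
    assert cluster_size > 0 and cluster_size & (cluster_size - 1) == 0
    assert n % cluster_size == 0, "subgroup must divide evenly into clusters"
    acc = list(values)
    offset = 1
    while offset < cluster_size:
        acc = [acc[lane] + acc[lane ^ offset] for lane in range(n)]
        offset *= 2
    return acc


if __name__ == "__main__":
    result = clustered_reduce([float(i) for i in range(16)], 8)
    print(result[0], result[15])  # 28.0 92.0
```

With cluster_size equal to the subgroup size this degenerates to an ordinary full-subgroup reduce, which is why only the clustered case needs extra lowering support.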

Reproduction:

Input.mlir

module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformArithmetic, DotProduct, DotProductInput4x8BitPacked, DotProductInputAll, DotProductInput4x8Bit], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product]>, ARM, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512 : i32, 512 : i32, 512 : i32], subgroup_size = 16, min_subgroup_size = 16, max_subgroup_size = 16, cooperative_matrix_properties_khr = []>>} {
  func.func @subgroup_reduce() {
    %c7_i32 = arith.constant 7 : i32
    %c0_i32 = arith.constant 0 : i32
    %c16_i32 = arith.constant 16 : i32
    %c128_i32 = arith.constant 128 : i32
    %c2 = arith.constant 2 : index
    %c256 = arith.constant 256 : index
    %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %thread_id_x = gpu.thread_id  x upper_bound 128
    %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(0) alignment(64) offset(%c0) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%c256}
    %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(1) alignment(64) offset(%c0) : memref<?xf32, #spirv.storage_class<StorageBuffer>>{%c2}
    %workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 2 : index
    %2 = arith.index_castui %workgroup_id_x : index to i32
    %3 = arith.muli %2, %c128_i32 overflow<nsw> : i32
    %4 = arith.index_castui %thread_id_x : index to i32
    %5 = arith.addi %3, %4 : i32
    %6 = arith.index_castui %5 : i32 to index
    %7 = memref.load %0[%6] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
    %8 = arith.addf %7, %cst : vector<4xf32>
    %9 = vector.reduction <add>, %8 : vector<4xf32> into f32
    %10 = gpu.subgroup_reduce  add %9 cluster(size = 16) : (f32) -> f32
    %alloc = memref.alloc() : memref<8xf32, #spirv.storage_class<Workgroup>>
    %11 = arith.divui %4, %c16_i32 : i32
    %12 = arith.index_castui %11 : i32 to index
    %13 = arith.remui %4, %c16_i32 : i32
    %14 = arith.cmpi eq, %13, %c0_i32 : i32
    scf.if %14 {
      memref.store %10, %alloc[%12] : memref<8xf32, #spirv.storage_class<Workgroup>>
    }
    gpu.barrier
    %15 = arith.minui %13, %c7_i32 : i32
    %16 = arith.index_castui %15 : i32 to index
    %17 = memref.load %alloc[%16] : memref<8xf32, #spirv.storage_class<Workgroup>>
    %18 = gpu.subgroup_reduce  add %17 cluster(size = 8) : (f32) -> f32
    %19 = arith.addf %18, %cst_0 : f32
    %20 = arith.cmpi eq, %4, %c0_i32 : i32
    scf.if %20 {
      memref.store %19, %1[%workgroup_id_x] : memref<?xf32, #spirv.storage_class<StorageBuffer>>
    }
    return
  }
}

Command:

iree-opt --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(iree-convert-to-spirv)' <Input.mlir>
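To show why the cross-warp step needs a cluster size, the following Python sketch models on the host what the reproduction kernel computes for its 256-element input: a reduce over each 16-lane subgroup, one partial per subgroup stored to workgroup memory, and a second reduce over those 8 partials with cluster(size = 8). It is an illustration under assumptions (sequential execution stands in for gpu.barrier; integer-valued test data keeps the float sums exact), and clustered_reduce, run_workgroup and run_kernel are hypothetical helpers, not IREE APIs.

```python
# Hypothetical host-side model of the reproduction kernel: two workgroups of
# 128 threads, subgroup_size = 16, one vector<4xf32> per thread.
SUBGROUP_SIZE = 16
WORKGROUP_SIZE = 128
NUM_SUBGROUPS = WORKGROUP_SIZE // SUBGROUP_SIZE  # 8, matching memref<8xf32>


def clustered_reduce(values, cluster_size):
    # Sum within aligned clusters of `cluster_size` lanes via xor butterflies.
    acc = list(values)
    offset = 1
    while offset < cluster_size:
        acc = [acc[lane] + acc[lane ^ offset] for lane in range(len(acc))]
        offset *= 2
    return acc


def run_workgroup(inputs, workgroup_id):
    base = workgroup_id * WORKGROUP_SIZE
    # %7..%9: each thread loads its vector<4xf32> and reduces it to one f32.
    per_thread = [sum(inputs[base + t]) for t in range(WORKGROUP_SIZE)]
    # %10: subgroup_reduce cluster(size = 16), i.e. one full subgroup per cluster.
    stage1 = clustered_reduce(per_thread, SUBGROUP_SIZE)
    # scf.if (lane == 0): lane 0 of each subgroup stores its partial to %alloc.
    shared = [stage1[sg * SUBGROUP_SIZE] for sg in range(NUM_SUBGROUPS)]
    # After gpu.barrier, every thread loads %alloc[min(lane, 7)].
    reloaded = [shared[min(t % SUBGROUP_SIZE, NUM_SUBGROUPS - 1)]
                for t in range(WORKGROUP_SIZE)]
    # %18: only 8 partials exist, so the second reduce uses cluster(size = 8).
    stage2 = clustered_reduce(reloaded, NUM_SUBGROUPS)
    # Thread 0 stores the workgroup total to output[workgroup_id].
    return stage2[0]


def run_kernel(inputs):
    num_workgroups = len(inputs) // WORKGROUP_SIZE  # 256 // 128 = 2 in the repro
    return [run_workgroup(inputs, wg) for wg in range(num_workgroups)]


if __name__ == "__main__":
    inputs = [[float(i)] * 4 for i in range(256)]
    print(run_kernel(inputs))  # [32512.0, 98048.0]
```

Lanes 8 to 15 of each subgroup also run the second reduce on duplicated copies of %alloc[7]; their results are never stored, so only the clustered reduce over lanes 0 to 7 matters.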


What component(s) does this issue relate to?

No response

Version information

No response

Additional context

No response

Muzammiluddin-Syed-ECE, May 21 '25 03:05

Here is a change to add clustered reduce support for SPIRV https://github.com/llvm/llvm-project/pull/141402.

fairywreath, May 25 '25 09:05