[CodeGen][SPIRV] Lowering for clustered reduce not implemented
What happened?
Context
To make effective use of the DPP operations available on AMD GPUs, the PR below changed the warp-reduction implementation to preserve subgroup_reduce ops rather than immediately lowering them to a butterfly shuffle built from gpu.shuffle xor ops. Preserving the subgroup_reduce ops enables lowering to target-specific ops later in the pipeline, where such ops exist. With this PR, reductions both within a warp and across warps can be expressed using subgroup_reduce ops.
https://github.com/iree-org/iree/pull/20468
However, SPIR-V lowering support for clustered subgroup_reduce is incomplete, which makes reduction across multiple warps difficult: a workgroup is not guaranteed to contain 64 x 64 threads, so being able to perform a subgroup_reduce over fewer than 64 threads (a cluster) is useful.
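For reference, before the PR above a subgroup reduction was expanded directly into a gpu.shuffle xor butterfly. Below is a hand-written sketch (not compiler output; value names are illustrative) of what that expansion looks like for an f32 add-reduction over 16-lane clusters of a 64-lane subgroup:

```mlir
// Illustrative only: the shuffle-xor butterfly the old lowering emitted for
//   %r = gpu.subgroup_reduce add %val cluster(size = 16) : (f32) -> f32
// Offsets 1, 2, 4, 8 reduce within each 16-lane cluster; gpu.shuffle
// returns the shuffled value and an i1 validity flag.
%w = arith.constant 64 : i32
%o1 = arith.constant 1 : i32
%s1, %p1 = gpu.shuffle xor %val, %o1, %w : f32
%r1 = arith.addf %val, %s1 : f32
%o2 = arith.constant 2 : i32
%s2, %p2 = gpu.shuffle xor %r1, %o2, %w : f32
%r2 = arith.addf %r1, %s2 : f32
%o4 = arith.constant 4 : i32
%s4, %p4 = gpu.shuffle xor %r2, %o4, %w : f32
%r4 = arith.addf %r2, %s4 : f32
%o8 = arith.constant 8 : i32
%s8, %p8 = gpu.shuffle xor %r4, %o8, %w : f32
%r8 = arith.addf %r4, %s8 : f32
// %r8 now holds the 16-lane cluster sum in every lane of the cluster.
```

Keeping the single subgroup_reduce op instead of this expansion is what allows later lowering to target-specific ops (e.g. DPP) where available.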
Reproduction:
Input.mlir
module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, Float16, Int16, Int8, StorageBuffer16BitAccess, StorageUniform16, StoragePushConstant16, StorageBuffer8BitAccess, UniformAndStorageBuffer8BitAccess, StoragePushConstant8, GroupNonUniformShuffle, GroupNonUniformShuffleRelative, GroupNonUniformArithmetic, DotProduct, DotProductInput4x8BitPacked, DotProductInputAll, DotProductInput4x8Bit], [SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_integer_dot_product]>, ARM, #spirv.resource_limits<max_compute_shared_memory_size = 32768, max_compute_workgroup_invocations = 512, max_compute_workgroup_size = [512 : i32, 512 : i32, 512 : i32], subgroup_size = 16, min_subgroup_size = 16, max_subgroup_size = 16, cooperative_matrix_properties_khr = []>>} {
  func.func @subgroup_reduce() {
    %c7_i32 = arith.constant 7 : i32
    %c0_i32 = arith.constant 0 : i32
    %c16_i32 = arith.constant 16 : i32
    %c128_i32 = arith.constant 128 : i32
    %c2 = arith.constant 2 : index
    %c256 = arith.constant 256 : index
    %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
    %cst_0 = arith.constant 0.000000e+00 : f32
    %c0 = arith.constant 0 : index
    %thread_id_x = gpu.thread_id x upper_bound 128
    %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(0) alignment(64) offset(%c0) : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>{%c256}
    %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer>, #hal.pipeline.binding<storage_buffer>]>) binding(1) alignment(64) offset(%c0) : memref<?xf32, #spirv.storage_class<StorageBuffer>>{%c2}
    %workgroup_id_x = hal.interface.workgroup.id[0] upper_bound 2 : index
    %2 = arith.index_castui %workgroup_id_x : index to i32
    %3 = arith.muli %2, %c128_i32 overflow<nsw> : i32
    %4 = arith.index_castui %thread_id_x : index to i32
    %5 = arith.addi %3, %4 : i32
    %6 = arith.index_castui %5 : i32 to index
    %7 = memref.load %0[%6] : memref<?xvector<4xf32>, #spirv.storage_class<StorageBuffer>>
    %8 = arith.addf %7, %cst : vector<4xf32>
    %9 = vector.reduction <add>, %8 : vector<4xf32> into f32
    %10 = gpu.subgroup_reduce add %9 cluster(size = 16) : (f32) -> f32
    %alloc = memref.alloc() : memref<8xf32, #spirv.storage_class<Workgroup>>
    %11 = arith.divui %4, %c16_i32 : i32
    %12 = arith.index_castui %11 : i32 to index
    %13 = arith.remui %4, %c16_i32 : i32
    %14 = arith.cmpi eq, %13, %c0_i32 : i32
    scf.if %14 {
      memref.store %10, %alloc[%12] : memref<8xf32, #spirv.storage_class<Workgroup>>
    }
    gpu.barrier
    %15 = arith.minui %13, %c7_i32 : i32
    %16 = arith.index_castui %15 : i32 to index
    %17 = memref.load %alloc[%16] : memref<8xf32, #spirv.storage_class<Workgroup>>
    %18 = gpu.subgroup_reduce add %17 cluster(size = 8) : (f32) -> f32
    %19 = arith.addf %18, %cst_0 : f32
    %20 = arith.cmpi eq, %4, %c0_i32 : i32
    scf.if %20 {
      memref.store %19, %1[%workgroup_id_x] : memref<?xf32, #spirv.storage_class<StorageBuffer>>
    }
    return
  }
}
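The module above performs a two-stage workgroup reduction: each 16-lane cluster reduces its partial sums with subgroup_reduce cluster(size = 16), lane 0 of each cluster writes the result to workgroup memory, and after a barrier the eight cluster sums are reduced with cluster(size = 8). A plain Python model of that dataflow (illustrative only; the function name is invented here):

```python
# Illustrative Python model of the two-stage reduction in the module above:
# 128 threads per workgroup, subgroup size 16, hence 8 clusters of 16 lanes.
def workgroup_reduce(values):
    """values: 128 per-thread f32 partial sums; returns the workgroup total."""
    assert len(values) == 128
    # Stage 1: gpu.subgroup_reduce add ... cluster(size = 16)
    cluster_sums = [sum(values[c * 16:(c + 1) * 16]) for c in range(8)]
    # Lane 0 of each cluster stores its sum to workgroup memory (%alloc);
    # after gpu.barrier, stage 2 reduces them with cluster(size = 8).
    return sum(cluster_sums)

print(workgroup_reduce([1.0] * 128))  # one partial sum per thread -> 128.0
```

Both stages rely on clustered subgroup_reduce, which is exactly what the SPIR-V lowering does not yet support.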
Command:
iree-opt --iree-gpu-test-target=valhall1 --pass-pipeline='builtin.module(iree-convert-to-spirv)' Input.mlir
Steps to reproduce your issue
- Save the module above as Input.mlir.
- Run the iree-opt command above.
- Observe that lowering of the clustered gpu.subgroup_reduce ops fails.
Additional context
A change adding clustered reduce support to the SPIR-V backend is in progress: https://github.com/llvm/llvm-project/pull/141402