
[GPU] Clustered Subgroup Reduction

Open · Groverkss opened this issue on Aug 07 '24 · 4 comments

Request description

Motivation

A pattern we notice in flash attention kernels is:

A: tensor<16x16xf16>
B: tensor<16x16xf16>
C: tensor<16x16xf16>

D : tensor<16x16xf16> = matmul(A, B, C)
E : tensor<16x1xf16>  = reduce(D, dim=1)
F : tensor<16x16xf16> = broadcast(E, dim=1)
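
For concreteness, a rough MLIR rendering of this pattern at the linalg level might look like the following. This is only a sketch: the value names, the addf combiner, and the init tensors are assumptions, and linalg.reduce drops the unit dimension that the pseudocode keeps in E.

%D = linalg.matmul
       ins(%A, %B : tensor<16x16xf16>, tensor<16x16xf16>)
       outs(%C : tensor<16x16xf16>) -> tensor<16x16xf16>

// Reduce along dim 1 (an addf reduction is assumed here).
%E = linalg.reduce
       ins(%D : tensor<16x16xf16>)
       outs(%red_init : tensor<16xf16>)
       dimensions = [1]
       (%in: f16, %out: f16) {
         %sum = arith.addf %in, %out : f16
         linalg.yield %sum : f16
       }

// Broadcast the reduced values back along dim 1.
%F = linalg.broadcast
       ins(%E : tensor<16xf16>)
       outs(%bcast_init : tensor<16x16xf16>)
       dimensions = [1]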

For performance, we prefer to keep this computation entirely in registers, which means the distribution of data across threads cannot change between these operations. Since we primarily optimize for matmul intrinsics on GPUs, the thread distribution is dictated by the intrinsic.

Accordingly, the thread distribution for tensor D follows that of the output of a matmul intrinsic. One example of such a distribution for a 16x16 shape is each thread carrying a vector<1x4xf16>, distributed over a 16x4 thread grid.

The data can be thought of as distributed across threads as follows (the numbers represent thread ids, and the matrix over which they are distributed is tensor D):

[ 0  0  0  0  1  1  1  1  2  2  2  2  3  3  3  3]
[ 4  4  4  4  5  5  5  5  6  6  6  6  7  7  7  7]
[ 8  8  8  8  9  9  9  9 10 10 10 10 11 11 11 11]
 .
 .
 .
[60 60 60 60 61 61 61 61 62 62 62 62 63 63 63 63]

When reducing this along dimension 1, we want to perform multiple reductions in parallel, in "clusters" of 4 threads (below, an element jointly carried by multiple thread ids is written with commas):

[0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3] --> [0, 1, 2, 3]
[4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7] --> [4, 5, 6, 7]
...

We will call such a reduction a "clustered" reduction.

How we do it today

Today, we directly emit a bunch of gpu.shuffle ops and do the entire reduction lowering in one shot:

https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp#L402
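
Conceptually, for a cluster of 4 and an add combiner, that lowering boils down to a two-step (log2 4) butterfly of xor shuffles per thread. A simplified scalar sketch, assuming an f32 partial value and a subgroup size of 64 (the real pattern also handles vectors and packing/bitcasting of small element types):

%width = arith.constant 64 : i32   // subgroup size (assumed to be 64)
%c1 = arith.constant 1 : i32
%c2 = arith.constant 2 : i32
// Step 1: exchange with the lane whose id differs in bit 0, then accumulate.
%x1, %valid1 = gpu.shuffle xor %partial, %c1, %width : f32
%a1 = arith.addf %partial, %x1 : f32
// Step 2: exchange with the lane whose id differs in bit 1, then accumulate.
%x2, %valid2 = gpu.shuffle xor %a1, %c2, %width : f32
%sum = arith.addf %a1, %x2 : f32
// Every lane in each aligned group of 4 lanes now holds that group's sum.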

Better lowering for clustered reductions

There is an existing operation in MLIR's gpu dialect, which can represent a reduction across threads in a subgroup:

https://mlir.llvm.org/docs/Dialects/GPU/#gpusubgroup_reduce-gpusubgroupreduceop

The limitation of this operation is that it reduces across all threads in the subgroup, which means it cannot express a clustered reduction.

We would like to add support for such clustered reductions to this operation.
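
With such support, a clustered reduction would become a single op. A hypothetical form, using the attribute name from the task list below (the printed syntax is illustrative, not a settled design):

// Reduce %partial across clusters of 4 contiguous lanes; every lane in a
// cluster receives that cluster's result. cluster_size does not exist yet.
%sum = gpu.subgroup_reduce add %partial {cluster_size = 4 : i32} : (f16) -> f16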

Tasks

  • Add a new "cluster_size" attribute on the gpu.subgroup_reduce operation
  • Update the existing gpu.subgroup_reduce lowering to emit clustered subgroup reductions.
  • Update the vector distribution pattern to use clustered reductions (see the sketch below).
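
To illustrate the last item for the 16x16 example above: each thread would first reduce the vector<1x4xf16> it holds locally and then issue one clustered reduce instead of the hand-rolled shuffle chain. A rough per-thread sketch (hypothetical cluster_size syntax as above; value names are made up):

// %tile is this thread's vector<1x4xf16> slice of D.
%row = vector.extract %tile[0] : vector<4xf16> from vector<1x4xf16>
%partial = vector.reduction <add>, %row : vector<4xf16> into f16
%row_sum = gpu.subgroup_reduce add %partial {cluster_size = 4 : i32} : (f16) -> f16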

Useful Links

  • Current subgroup_reduce lowering: https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
  • GPU Dialect documentation: https://mlir.llvm.org/docs/Dialects/GPU/#gpusubgroup_reduce-gpusubgroupreduceop
  • IR definition for gpu.subgroup_reduce: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td#L1196

What component(s) does this issue relate to?

MLIR, Compiler

Additional context

No response

Groverkss · Aug 07 '24 17:08