[gfx950][mxfp4] Avoid bank conflicts on matrix inputs
Analysis of traces (thanks @sebvince) has shown that the mxfp4 kernel spends a substantial amount of time stalled on LDS bank conflicts on the matrix inputs, i.e. matrices A and B (the scale inputs are a separate issue).
As with other matmuls, swizzles are needed to eliminate these conflicts.
We expect this to improve performance substantially (1.25x, say, maybe more?) but will need to validate that.
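For context, here's a minimal model of why strided LDS accesses conflict. The constants assume the usual GCN/CDNA banking of 32 banks of 4 bytes each; treat the gfx950 specifics as an assumption.

```c++
#include <cstdio>

// Bank of a byte address under a 32-bank, 4-bytes-per-bank LDS model.
int bankOf(int byteAddr) { return (byteAddr / 4) % 32; }

int main() {
  // 32 lanes each read the first dword of consecutive rows of a row-major
  // tile whose rows are 128 bytes wide: every lane lands in bank 0, so the
  // hardware serializes the access instead of servicing all lanes at once.
  for (int lane = 0; lane < 32; ++lane)
    printf("lane %2d -> bank %2d\n", lane, bankOf(lane * 128));
  return 0;
}
```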
Notes on resolution: Discussion with @sebvince indicated that he'd run experiments on these same bank conflicts in the Wave kernel, that the traditional padding approach was not feasible there, and that an XOR-based swizzle (or perhaps the row-rotation swizzle) was needed. He linked the following files:
https://github.com/sebvince/test_bank_conflict/blob/main/kernel_f32_async_hint.mlir and https://github.com/sebvince/test_bank_conflict/blob/main/kernel_f32_async_ref.mlir
@qedawkins What's a good way to quickly apply these swizzling strategies in IREE so we can get empirical results for how well bank conflicts are avoided / what the perf gains here are?
Misc. notes: To check bank conflicts in ATT: rocprofv3 --att-perfcounter-ctrl 3 --att-perfcounters "SQ_LDS_BANK_CONFLICT" -- ./your-application
Ok, so, I tried a simple experiment, which made things slower. I suspect that's me holding it wrong (I haven't checked traces yet).
The patch
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 4bd0012a8a..c85981fd42 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
@@ -359,6 +360,15 @@ static FailureOr<Value> gpuRequireMemSpaceAllocationFn(OpBuilder &builder,
allocType =
MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
AffineMap(), workgroupSpace);
+ if (isa<Float4E2M1FNType>(memRefType.getElementType())) {
+ auto flatAllocType = MemRefType::get(ArrayRef<int64_t>{memRefType.getNumElements()}, memRefType.getElementType(), AffineMap(), workgroupSpace);
+ Value flatAlloc = memref::AllocOp::create(builder, loc, flatAllocType);
+ Value swizzled = iree_compiler::IREE::Codegen::SwizzleHintOp::create(builder, loc, flatAlloc,
+ iree_compiler::IREE::Codegen::XORShuffleAttr::get(builder.getContext(), 256, 32, 256, 1));
+ ReassociationIndices reassoc = llvm::to_vector(llvm::seq(allocType.getRank()));
+ Value expanded = memref::ExpandShapeOp::create(builder, loc, allocType.getShape(), swizzled, {reassoc});
+ return expanded;
+ }
return memref::AllocOp::create(builder, loc, allocType, dynamicSizes)
.getResult();
}
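For readers unfamiliar with the technique, the general shape of an XOR swizzle is sketched below. This is a rough mental model only; the exact semantics of iree_codegen.xor_shuffle and its four parameters are defined by IREE, and the helper name here is mine.

```c++
#include <cstdint>

// Generic XOR swizzle over a flat buffer: split each rowWidth-element row
// into accessWidth-element chunks, then XOR the chunk index with the row
// index so that the same column lands in a different chunk (and therefore
// different banks) on every row. The mapping is a bijection per row.
int64_t xorSwizzle(int64_t idx, int64_t rowWidth, int64_t accessWidth) {
  int64_t row = idx / rowWidth;
  int64_t col = idx % rowWidth;
  int64_t numChunks = rowWidth / accessWidth;
  int64_t chunk = (col / accessWidth) ^ (row % numChunks);
  return row * rowWidth + chunk * accessWidth + col % accessWidth;
}
```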
The input
MLIR input file for this reproducer
//<dim> 8192 </dim>
//<dim> 8192 </dim>
//<dim> 256 </dim>
//<dim> 32 </dim>
//<input>tensor<8192x256x32xi8></input>
//<input>tensor<8192x256xi8></input>
//<input>tensor<8192x256x32xi8></input>
//<input>tensor<8192x256xi8></input>
!lhs = f4E2M1FN
!rhs = f4E2M1FN
!scale_ty = f8E8M0FNU
!A = tensor<8192x256x32xi8>
!A_i4 = tensor<8192x256x32xi4>
!B = tensor<8192x256x32xi8>
!B_i4 = tensor<8192x256x32xi4>
!A_fp4 = tensor<8192x256x32x!lhs>
!B_fp4 = tensor<8192x256x32x!rhs>
!A_scales = tensor<8192x256x!scale_ty>
!B_scales = tensor<8192x256x!scale_ty>
!A_s = tensor<8192x256xi8>
!B_s = tensor<8192x256xi8>
!C = tensor<8192x8192xf32>
#lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
#rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
#scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
#scale_n = affine_map<(M, N, Ko, Kb) -> (N, Ko)>
#out_map = affine_map<(M, N, Ko, Kb) -> (M, N)>
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
%A_i4 = arith.trunci %lhs : !A to !A_i4
%B_i4 = arith.trunci %rhs : !B to !B_i4
%A_scales = arith.bitcast %lhs_scales : !A_s to !A_scales
%B_scales = arith.bitcast %rhs_scales : !B_s to !B_scales
%A = arith.bitcast %A_i4 : !A_i4 to !A_fp4
%B = arith.bitcast %B_i4 : !B_i4 to !B_fp4
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : !C
%C = linalg.fill ins(%cst : f32) outs(%empty : !C) -> !C
%D = linalg.generic {
indexing_maps = [#lhs_map, #rhs_map, #scale_m, #scale_n, #out_map],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]
} ins(%A, %B, %A_scales, %B_scales : !A_fp4, !B_fp4, !A_scales, !B_scales) outs(%C : !C) {
^bb0(%a: !lhs, %b: !rhs, %a_scale: !scale_ty, %b_scale: !scale_ty, %out: f32):
%1 = arith.scaling_extf %a, %a_scale : !lhs, !scale_ty to f32
%2 = arith.scaling_extf %b, %b_scale : !rhs, !scale_ty to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> !C
return %D : !C
}
The compile command
iree-compile --iree-opt-level=O3 --iree-hip-waves-per-eu=2 --iree-dispatch-creation-test-set-scaled-matmul-encodings=false xxx.mlir -o xxx.vmfb --iree-hal-target-device=hip --iree-hip-target=gfx950
This did, however, yield a side observation: we seem to be missing an apply/delinearize cancellation, which might not be helping.
Looking at the IR, it seems that the LDS tile is 32x128.
%14 = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<256, 32, 256, 1>] : memref<131072xf4E2M1FN, #gpu.address_space<workgroup>>
%expand_shape_0 = memref.expand_shape %14 [[0, 1, 2]] output_shape [32, 128, 32] : memref<131072xf4E2M1FN, #gpu.address_space<workgroup>> into memref<32x128x32xf4E2M1FN, #gpu.address_space<workgroup>>
Try using row_width=4096 (32 * 128) instead of 256.
XORShuffleAttr::get(builder.getContext(), 4096, 32, 4096, 1));
@sebvince Thanks for spotting the obvious mistake!
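To spell out why 256 fails while 4096 works, at least under the generic xorSwizzle model sketched earlier (an assumption about xor_shuffle's semantics, not a statement of them): the XOR key comes from idx / rowWidth, so rowWidth has to match the real row stride of the flattened tile, which is 128 * 32 = 4096 fp4 elements here.

```c++
#include <cstdint>
#include <cstdio>

// Same generic model as the earlier sketch (assumed, not IREE's exact op).
int64_t xorSwizzle(int64_t idx, int64_t rowWidth, int64_t accessWidth) {
  int64_t row = idx / rowWidth;
  int64_t col = idx % rowWidth;
  int64_t numChunks = rowWidth / accessWidth;
  int64_t chunk = (col / accessWidth) ^ (row % numChunks);
  return row * rowWidth + chunk * accessWidth + col % accessWidth;
}

int main() {
  // Column 0 of consecutive real rows; the real row stride is 128 * 32 = 4096.
  // rowWidth=256: idx / 256 advances by 16 per real row, and 16 is a multiple
  // of numChunks = 256 / 32 = 8, so every real row gets the same XOR key and
  // the in-row offset stays 0 -- the conflicts survive.
  // rowWidth=4096: keys advance by 1 per row, giving offsets 0, 32, 64, 96.
  for (int64_t r = 0; r < 4; ++r) {
    int64_t idx = r * 4096;
    printf("row %lld: off256=%lld off4096=%lld\n", (long long)r,
           (long long)(xorSwizzle(idx, 256, 32) % 4096),
           (long long)(xorSwizzle(idx, 4096, 32) % 4096));
  }
  return 0;
}
```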
Now, the timings:
Without xor:
BM_scaled_matmul/process_time/real_time_median 5.35 ms 5.36 ms 10 items_per_second=187.032/s
BM_scaled_matmul/process_time/real_time_stddev 0.002 ms 0.007 ms 10 items_per_second=0.0754851/s
With xor:
BM_scaled_matmul/process_time/real_time_median 3.78 ms 3.80 ms 10 items_per_second=264.49/s
BM_scaled_matmul/process_time/real_time_stddev 0.003 ms 0.004 ms 10 items_per_second=0.216108/s
This is a 1.4x improvement (5.35 ms / 3.78 ms ≈ 1.42x)!
The patch that yielded this
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 4bd0012a8a..1e43604d53 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -12,6 +12,7 @@
#include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
#include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
#include "iree/compiler/Codegen/LLVMGPU/Passes.h"
@@ -359,6 +360,15 @@ static FailureOr<Value> gpuRequireMemSpaceAllocationFn(OpBuilder &builder,
allocType =
MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
AffineMap(), workgroupSpace);
+ if (isa<Float4E2M1FNType>(memRefType.getElementType())) {
+ auto flatAllocType = MemRefType::get(ArrayRef<int64_t>{memRefType.getNumElements()}, memRefType.getElementType(), AffineMap(), workgroupSpace);
+ Value flatAlloc = memref::AllocOp::create(builder, loc, flatAllocType);
+ Value swizzled = iree_compiler::IREE::Codegen::SwizzleHintOp::create(builder, loc, flatAlloc,
+ iree_compiler::IREE::Codegen::XORShuffleAttr::get(builder.getContext(), 4096, 32, 4096, 1));
+ ReassociationIndices reassoc = llvm::to_vector(llvm::seq(allocType.getRank()));
+ Value expanded = memref::ExpandShapeOp::create(builder, loc, allocType.getShape(), swizzled, {reassoc});
+ return expanded;
+ }
return memref::AllocOp::create(builder, loc, allocType, dynamicSizes)
.getResult();
}
So I think the question now is how to productionize this. I believe there's a bank-conflict-avoidance pass that ... adds these swizzle operations?
Plan:
- Teach swizzle hints to take tensors
- Make some swizzle attributes (perhaps with a wrapper) implement PromotionAttr so we can use existing promotion infrastructure and GPU config stuff to set them
- Make sure reduction tiling can look through the swizzle hints.
Per @kuhar: we shouldn't use the swizzles directly (too many knobs for the tuner, the interface, etc.). Just have an enum and materialize the parameters later.
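A hypothetical sketch of that direction (all names below are invented for illustration, not existing IREE API): the lowering config names only a coarse scheme, and the concrete parameters get derived from the actual allocation late in the pipeline, so the tuner never sees the raw knobs.

```c++
#include <cstdint>

// Hypothetical: configs carry only a coarse scheme enum.
enum class SwizzleScheme { None, XorMatrixInput };

struct XorShuffleParams {
  int64_t rowWidth;
  int64_t accessWidth;
};

// Hypothetical materialization: derive the numbers from the flattened LDS
// tile (e.g. innerRowElems = 128 * 32 = 4096 in the experiment above).
XorShuffleParams materialize(SwizzleScheme scheme, int64_t innerRowElems) {
  if (scheme == SwizzleScheme::None)
    return {0, 0};
  return {/*rowWidth=*/innerRowElems, /*accessWidth=*/32};
}
```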
After investigating various shapes and experiments, what we've found works best for large GEMMs (note that I'm not using the definition of large GEMM used in GPUMMAHeuristics) is a combination of the following optimizations:
- FP4 swizzling with row width 256, access width 32
- K Tile Element Count: 256
- MN subgroup count: 64 (meaning each subgroup produces a 128x128 output tile)
- Scales packing
This provides a 51% geomean improvement across the shapes checked (per-shape deltas noted where recorded):
512_512_256_16 (-0.88%)
1024_512_256_16 (-25%)
8192_512_256_16
16384_512_256_16
53248_512_256_16
512_1024_8192_512
512_16384_8192_512
512_53248_8192_512
1024_1024_8192_512
512_16384_26624_1664
8192_1024_8192_512 (-10%)
16384_1024_8192_512
1024_16384_8192_512
53248_1024_8192_512
1024_53248_8192_512
1024_16384_26624_1664
8192_16384_8192_512
8192_53248_8192_512 (+108%)
16384_16384_8192_512
53248_16384_8192_512
16384_53248_8192_512 (+101%)
53248_53248_8192_512
8192_16384_26624_1664
16384_16384_26624_1664
53248_16384_26624_1664 (+112%)
This is the geomean improvement with each individual optimization disabled (a rough per-optimization attribution is sketched after the list):
- Baseline (everything enabled): +51%
- FP4 swizzling: +33%
- K tile element count: +49%
- MN count: +31%
- ScalePacking: +48%
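To make the attribution explicit, this is just arithmetic on the geomeans above: the rough contribution of one optimization is the baseline geomean over the geomean with that optimization disabled.

```c++
#include <cstdio>

// Rough per-optimization contribution implied by the ablation geomeans.
double contribution(double baseline, double ablated) {
  return (1.0 + baseline) / (1.0 + ablated) - 1.0;
}

int main() {
  printf("FP4 swizzling: %+.1f%%\n", 100 * contribution(0.51, 0.33)); // ~+13.5%
  printf("K tile count:  %+.1f%%\n", 100 * contribution(0.51, 0.49)); // ~+1.3%
  printf("MN count:      %+.1f%%\n", 100 * contribution(0.51, 0.31)); // ~+15.3%
  printf("Scale packing: %+.1f%%\n", 100 * contribution(0.51, 0.48)); // ~+2.0%
  return 0;
}
```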
This implies that for shapes with very large arithmetic intensity we get the greatest boost in performance from improving the subgroup output tile size (MN count) and adding FP4 swizzling: disabling either of those drops the geomean the most.
Thanks for getting all this data!
TODO: write up to explain the current state.
CC: @kuhar