
[gfx950][mxfp4] Avoid bank conflicts on matrix inputs

Open · krzysz00 opened this issue 2 months ago · 7 comments

Analysis of traces (thanks @sebvince) has shown that the mxfp4 kernel spends a substantial amount of time stalled on LDS bank conflicts on the matrix inputs, that is, matrices A and B (as opposed to the scale inputs, which are a separate issue).

As with other matmuls, swizzles are needed in order to eliminate these conflicts.

We've got a sense that this'll substantially improve performance (1.25x, say, maybe more?) but we'll need to validate that.

Notes on resolution: discussion with @sebvince indicated that he'd run experiments on these same bank conflicts in the Wave kernel, that the traditional padding approach was not feasible, and that an XOR-based swizzle (or perhaps a row rotation swizzle) was needed. He linked the following files:

https://github.com/sebvince/test_bank_conflict/blob/main/kernel_f32_async_hint.mlir and https://github.com/sebvince/test_bank_conflict/blob/main/kernel_f32_async_ref.mlir
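
For intuition, here is a minimal standalone sketch (not IREE code) of the XOR-based swizzle idea, assuming the classic scheme where the flat LDS offset is split into rows and chunks and the chunk index is XORed with the row index. The parameter names and the chunk/bank model are illustrative assumptions, not the exact semantics of IREE's xor_shuffle attribute or of gfx950 LDS:

#include <cstdint>
#include <cstdio>

// Remap a flat index: split it into rows of `rowWidth` elements and chunks of
// `accessWidth` elements, then XOR the chunk index with the row index
// (wrapped to the number of chunks per row).
int64_t xorSwizzle(int64_t flatIndex, int64_t rowWidth, int64_t accessWidth) {
  int64_t row = flatIndex / rowWidth;
  int64_t col = flatIndex % rowWidth;
  int64_t chunk = col / accessWidth;
  int64_t withinChunk = col % accessWidth;
  // Permuting chunks per row means accesses marching down a column of
  // consecutive rows hit different chunks (and hence different banks)
  // instead of all landing on the same one.
  int64_t swizzledChunk = chunk ^ (row % (rowWidth / accessWidth));
  return row * rowWidth + swizzledChunk * accessWidth + withinChunk;
}

int main() {
  const int64_t rowWidth = 256, accessWidth = 32; // illustrative element counts
  // Column 0 of rows 0..7: unswizzled, every access is in chunk 0 of its row.
  for (int64_t row = 0; row < 8; ++row) {
    int64_t flat = row * rowWidth;
    int64_t swz = xorSwizzle(flat, rowWidth, accessWidth);
    std::printf("row %lld: chunk 0 -> chunk %lld\n", (long long)row,
                (long long)((swz % rowWidth) / accessWidth));
  }
  return 0;
}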

@qedawkins What's a good way to quickly apply these swizzling strategies in IREE so we can get empirical results for how well bank conflicts are avoided / what the perf gains here are?

Misc. notes: to check bank conflicts in ATT, run: rocprofv3 --att-perfcounter-ctrl 3 --att-perfcounters "SQ_LDS_BANK_CONFLICT" -- ./your-application

krzysz00 · Oct 09 '25

Ok, so, I tried a simple experiment, which made things slower. I suspect that's me holding it wrong (I haven't checked traces yet, though).

The patch

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 4bd0012a8a..c85981fd42 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
@@ -359,6 +360,15 @@ static FailureOr<Value> gpuRequireMemSpaceAllocationFn(OpBuilder &builder,
   allocType =
       MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
                       AffineMap(), workgroupSpace);
+  if (isa<Float4E2M1FNType>(memRefType.getElementType())) {
+    auto flatAllocType = MemRefType::get(ArrayRef<int64_t>{memRefType.getNumElements()}, memRefType.getElementType(), AffineMap(), workgroupSpace);
+    Value flatAlloc = memref::AllocOp::create(builder, loc, flatAllocType);
+    Value swizzled = iree_compiler::IREE::Codegen::SwizzleHintOp::create(builder, loc, flatAlloc,
+      iree_compiler::IREE::Codegen::XORShuffleAttr::get(builder.getContext(), 256, 32, 256, 1));
+    ReassociationIndices reassoc = llvm::to_vector(llvm::seq(allocType.getRank()));
+    Value expanded = memref::ExpandShapeOp::create(builder, loc, allocType.getShape(), swizzled, {reassoc});
+    return expanded;
+  }
   return memref::AllocOp::create(builder, loc, allocType, dynamicSizes)
       .getResult();
 }

The input

MLIR input file for this reproducer
//<dim> 8192 </dim>
//<dim> 8192 </dim>
//<dim> 256 </dim>
//<dim> 32 </dim>
//<input>tensor<8192x256x32xi8></input>
//<input>tensor<8192x256xi8></input>
//<input>tensor<8192x256x32xi8></input>
//<input>tensor<8192x256xi8></input>

!lhs = f4E2M1FN
!rhs = f4E2M1FN
!scale_ty = f8E8M0FNU

!A = tensor<8192x256x32xi8>
!A_i4 = tensor<8192x256x32xi4>
!B = tensor<8192x256x32xi8>
!B_i4 = tensor<8192x256x32xi4>
!A_fp4 = tensor<8192x256x32x!lhs>
!B_fp4 = tensor<8192x256x32x!rhs>

!A_scales = tensor<8192x256x!scale_ty>
!B_scales = tensor<8192x256x!scale_ty>
!A_s = tensor<8192x256xi8>
!B_s = tensor<8192x256xi8>

!C = tensor<8192x8192xf32>

#lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
#rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
#scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
#scale_n = affine_map<(M, N, Ko, Kb) -> (N, Ko)>
#out_map = affine_map<(M, N, Ko, Kb) -> (M, N)>
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
  %A_i4 = arith.trunci %lhs : !A to !A_i4
  %B_i4 = arith.trunci %rhs : !B to !B_i4
  %A_scales = arith.bitcast %lhs_scales : !A_s to !A_scales
  %B_scales = arith.bitcast %rhs_scales : !B_s to !B_scales
  %A = arith.bitcast %A_i4 : !A_i4 to !A_fp4
  %B = arith.bitcast %B_i4 : !B_i4 to !B_fp4
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : !C
  %C = linalg.fill ins(%cst : f32) outs(%empty : !C) -> !C

  %D = linalg.generic {
    indexing_maps = [#lhs_map, #rhs_map, #scale_m, #scale_n, #out_map],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"]
  } ins(%A, %B, %A_scales, %B_scales : !A_fp4, !B_fp4, !A_scales, !B_scales) outs(%C : !C) {
  ^bb0(%a: !lhs, %b: !rhs, %a_scale: !scale_ty, %b_scale: !scale_ty, %out: f32):
    %1 = arith.scaling_extf %a, %a_scale : !lhs, !scale_ty to f32
    %2 = arith.scaling_extf %b, %b_scale : !rhs, !scale_ty to f32
    %3 = arith.mulf %1, %2 : f32
    %4 = arith.addf %out, %3 : f32
    linalg.yield %4 : f32
  } -> !C
  return %D : !C
}

The compile command

iree-compile  --iree-opt-level=O3 --iree-hip-waves-per-eu=2 --iree-dispatch-creation-test-set-scaled-matmul-encodings=false xxx.mlir -o xxx.vmfb --iree-hal-target-device=hip --iree-hip-target=gfx950 

This has, however, given me the side observation that we're missing an apply/delinearize cancellation, which might not be helping matters.

krzysz00 · Oct 10 '25

Looking at the IR, it seems that the LDS tile is 32x128 (times 32 inner elements, i.e. 131072 = 32 * 128 * 32 fp4 values):

%14 = iree_codegen.swizzle_hint %alloc[#iree_codegen.xor_shuffle<256, 32, 256, 1>] : memref<131072xf4E2M1FN, #gpu.address_space<workgroup>>
%expand_shape_0 = memref.expand_shape %14 [[0, 1, 2]] output_shape [32, 128, 32] : memref<131072xf4E2M1FN, #gpu.address_space<workgroup>> into memref<32x128x32xf4E2M1FN, #gpu.address_space<workgroup>>

Try using row_width=4096 (32 * 128) instead of 256.

XORShuffleAttr::get(builder.getContext(), 4096, 32, 4096, 1));

sebvince · Oct 13 '25

@sebvince Thanks for spotting the obvious mistake!

Now, to give timings:

Without xor:

BM_scaled_matmul/process_time/real_time_median       5.35 ms         5.36 ms           10 items_per_second=187.032/s
BM_scaled_matmul/process_time/real_time_stddev      0.002 ms        0.007 ms           10 items_per_second=0.0754851/s

With xor:

BM_scaled_matmul/process_time/real_time_median       3.78 ms         3.80 ms           10 items_per_second=264.49/s
BM_scaled_matmul/process_time/real_time_stddev      0.003 ms        0.004 ms           10 items_per_second=0.216108/s

This is a 1.4x improvement (5.35 ms / 3.78 ms ≈ 1.42x)!

The patch that yielded this

diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
index 4bd0012a8a..1e43604d53 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -12,6 +12,7 @@
 #include "iree/compiler/Codegen/Common/CombineLayoutTransformation.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Common/Passes.h"
+#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Passes.h"
 #include "iree/compiler/Codegen/Dialect/VectorExt/Transforms/Passes.h"
 #include "iree/compiler/Codegen/LLVMGPU/Passes.h"
@@ -359,6 +360,15 @@ static FailureOr<Value> gpuRequireMemSpaceAllocationFn(OpBuilder &builder,
   allocType =
       MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
                       AffineMap(), workgroupSpace);
+  if (isa<Float4E2M1FNType>(memRefType.getElementType())) {
+    auto flatAllocType = MemRefType::get(ArrayRef<int64_t>{memRefType.getNumElements()}, memRefType.getElementType(), AffineMap(), workgroupSpace);
+    Value flatAlloc = memref::AllocOp::create(builder, loc, flatAllocType);
+    Value swizzled = iree_compiler::IREE::Codegen::SwizzleHintOp::create(builder, loc, flatAlloc,
+                                                                         iree_compiler::IREE::Codegen::XORShuffleAttr::get(builder.getContext(), 4096, 32, int64_t(), int64_t()));
+    ReassociationIndices reassoc = llvm::to_vector(llvm::seq(allocType.getRank()));
+    Value expanded = memref::ExpandShapeOp::create(builder, loc, allocType.getShape(), swizzled, {reassoc});
+    return expanded;
+  }
   return memref::AllocOp::create(builder, loc, allocType, dynamicSizes)
       .getResult();
 }

So I think the question now is how to productionize this. I think there's a bank conflict avoidance pass that ... adds these swizzle operations?

krzysz00 · Oct 13 '25

Plan:

  • Teach swizzle hints to take tensors
  • Make some swizzle attributes (perhaps with a wrapper) implement PromotionAttr so we can use existing promotion infrastructure and GPU config stuff to set them
  • Make sure reduction tiling can look through the swizzle hints.

krzysz00 · Oct 20 '25

Per @kuhar: we shouldn't use the swizzles directly, since that's too many knobs for the tuner, the interface, etc. Just have an enum and materialize the parameters later (see the sketch below).
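
A hedged sketch of what that could look like; SwizzleKind, SwizzleParams, and materializeSwizzleParams are hypothetical names, not existing IREE API:

#include <cstdint>
#include <optional>

// The tuner/interface only sees a coarse enum...
enum class SwizzleKind { None, XorShuffle };

struct SwizzleParams {
  int64_t rowWidth;    // e.g. the flattened LDS tile row: 128 * 32 = 4096 above
  int64_t accessWidth; // e.g. elements per contiguous access: 32 above
};

// ...and the concrete parameters are materialized late, derived from the tile
// shape rather than being tuner-visible knobs.
std::optional<SwizzleParams>
materializeSwizzleParams(SwizzleKind kind, int64_t tileRowElems,
                         int64_t elemsPerAccess) {
  switch (kind) {
  case SwizzleKind::None:
    return std::nullopt;
  case SwizzleKind::XorShuffle:
    return SwizzleParams{tileRowElems, elemsPerAccess};
  }
  return std::nullopt;
}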

krzysz00 · Oct 23 '25

After investigating various shapes and experiments, what we've found works best for large GEMMs (note that I'm not using the definition of large GEMM that's used in GPUMMAHeuristics) is a combination of the following optimizations:

  • FP4 swizzling with row width 256, access width 32
  • K Tile Element Count: 256
  • MN subgroup count: 64 (meaning each subgroup produces a 128x128 output tile)
  • Scales packing

This provides a 51% geomean improvement across the shapes checked (notable per-shape deltas in parentheses):

512_512_256_16 (-0.88%)
1024_512_256_16 (-25%)
8192_512_256_16 
16384_512_256_16
53248_512_256_16
512_1024_8192_512
512_16384_8192_512
512_53248_8192_512
1024_1024_8192_512
512_16384_26624_1664
8192_1024_8192_512 (-10%)
16384_1024_8192_512
1024_16384_8192_512
53248_1024_8192_512
1024_53248_8192_512
1024_16384_26624_1664
8192_16384_8192_512
8192_53248_8192_512 (+108%)
16384_16384_8192_512
53248_16384_8192_512
16384_53248_8192_512 (+101%)
53248_53248_8192_512
8192_16384_26624_1664
16384_16384_26624_1664
53248_16384_26624_1664 (+112%)

This is the geomean improvement with each individual config disabled:

  • Baseline (everything enabled): +51%
  • FP4 swizzling: +33%
  • K tile element count: +49%
  • MN count: +31%
  • ScalePacking: +48%

This implies that for shapes of very large arithmetic intensity, we get the greatest boost in performance from improving the subgroup output tile size (MN count) and adding FP4 swizzling: disabling either of those drops the geomean the most (to +31% and +33%, respectively).
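
For reference, a small sketch of how geomean improvement figures like these aggregate per-shape speedups; the sample values are illustrative, taken from the annotated shapes above:

#include <cmath>
#include <cstdio>
#include <vector>

// Geometric-mean improvement: average the logs of the per-shape speedups
// (new/old throughput), exponentiate, and express as a percentage.
double geomeanImprovementPct(const std::vector<double> &speedups) {
  double logSum = 0.0;
  for (double s : speedups)
    logSum += std::log(s);
  return (std::exp(logSum / speedups.size()) - 1.0) * 100.0;
}

int main() {
  // e.g. -25%, +108%, +101%, +112% correspond to 0.75, 2.08, 2.01, 2.12.
  std::vector<double> speedups = {0.75, 2.08, 2.01, 2.12};
  std::printf("geomean improvement: %+.1f%%\n",
              geomeanImprovementPct(speedups));
  return 0;
}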

Muzammiluddin-Syed-ECE · Dec 01 '25

Thanks for getting all this data!

krzysz00 · Dec 01 '25

Write up to explain state

Muzammiluddin-Syed-ECE · Dec 15 '25

CC: @kuhar

Muzammiluddin-Syed-ECE · Dec 15 '25