[GPU][DT] Implement mxfp4 MLIR ukernel
I'd like to get a status update for the mxfp4 MLIR ukernel. What is the current status and remaining work?
(I'm just checking in on the status. We can discuss more tomorrow, but we still need to update the issue.)
E2E Results: Scaled Matmul (FP4) - MxNxKxKB=10000x16384x512x32
Performance Summary
| Configuration | Performance |
|---|---|
| no dt | 8.08ms |
| no dt + uk | Not ready |
| dt + no uk | 4.70ms |
| dt + uk | 4.73ms |
Waveform Analysis
dt + no uk: (waveform screenshot)
dt + uk: (waveform screenshot)
Analysis
Why is "no dt + uk" not ready?
The implementation exists (https://github.com/iree-org/iree/compare/main...qedawkins:iree:fp4_pingpong#diff-185cc416a5fbb75606542bb5cd6df129015afb7d593eed4c5c876b47191bdcb6R42), but it performs poorly for this shape. The original plan was to start from the existing "no dt + uk" implementation (assuming it already performed well) and then adapt it to "dt + uk". That turned out not to be the case, so I have to do some further tuning on tile sizes and barriers, and so far I have only had time to apply these optimizations to the dt version.
Why does "dt + uk" have the same performance as "dt + no uk"?
The direct codegen already pipelines well. It even uses three waves per execution unit, which could give it some advantage in scheduling. So far, both implementations are bottlenecked by global memory reads, specifically the yellow blocks (VMEM instructions) and the empty regions (VMEM waits).
Next Steps on "dt + uk"
- Reduce memory indexing cost (light green blocks interleaved with yellow/orange). This might be achievable by collapsing some dimensions, which previously worked well for the f8 cases.
- Reduce wait time (empty regions between activities). It's unclear why these gaps are so significant. I suspect this is related to low cache hit rates, as the scaled matmul sees only a ~30% L1 hit rate. I saw the example of adding an `{aux = 3}` attribute (https://github.com/iree-org/iree/commit/06e6089ceb95cfbeb2e4e7bf33036c0e3eeba47e), but I'm not entirely sure what that attribute actually controls, and in my experiments it made performance worse.
Awesome progress @Yu-Zhewen ! FYI @sebvince .
Awesome ! We should catch up to discuss this :) . Happy to look at ATT trace if you have one.
for aux = 3 -> it sets sc0 sc1 meaning we bypass L1/L2 for global-to-lds loads.
Are you doing MFMAs just once every 2 iterations? I don't see any dark green on either the left or right side of the screenshot.
> for aux = 3 -> it sets sc0 sc1 meaning we bypass L1/L2 for global-to-lds loads.
I see, thanks for sharing the information!
> Are you doing MFMAs just once every 2 iterations?
Yeah, since I can't afford a larger tile size here. Even if I split the MFMA into 2 stages:
(LDS Stage 0, MFMA Stage 0) -> (LDS Stage 1, MFMA Stage 1)
(Global Mem Stage 0) -> (Global Mem Stage 1)
I don't think it will make a difference, since global memory is the bottleneck anyway.
This is quite different from the fp8 case on gfx942, where LDS is the bottleneck instead and we do
(All Global mem, LDS Stage 0) -> (MFMA Stage 0)
(Last iteration MFMA Stage 1) -> (LDS Stage 1)
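The argument above can be sketched with a toy bound model (all costs below are hypothetical units for illustration, not measured numbers): in a software-pipelined loop, the steady-state iteration cost is bounded by the slowest overlapped stage, so splitting the MFMA work into two per-iteration stages changes nothing while the global-memory stage dominates.

```python
# Toy steady-state bound model for a pipelined matmul loop.
# Costs are hypothetical illustrative units, NOT measurements.
global_cost = 10  # global-memory -> LDS traffic per iteration (dominant)
lds_cost = 2      # LDS reads per iteration
mfma_cost = 3     # MFMA work per iteration

# Scheme A: all MFMA issued in one stage (MFMAs once every 2 iterations).
# Overlapped stages are bounded by the slowest one.
iter_cost_a = max(global_cost, lds_cost + mfma_cost)

# Scheme B: MFMA split into 2 stages per iteration (fp8/gfx942-style split).
# The bound is unchanged because global memory still dominates.
iter_cost_b = max(global_cost, lds_cost + mfma_cost / 2)

# Both schemes hit the same global-memory-bound iteration cost.
assert iter_cost_a == iter_cost_b == global_cost
```

Under this model, restructuring the MFMA stages only pays off once the `lds + mfma` side exceeds the global-memory side, which matches the fp8/gfx942 situation described above.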
After discussing with @sebvince, we might need to relax `intrinsicsK`
https://github.com/iree-org/iree/blob/4bce5a2f3d21fa4c5594f1f51aa914b9aa4f80eb/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp#L208
and instead increase the M and N tile sizes to shift the kernel toward being more compute-bound rather than the current memory-bound behavior.
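A back-of-the-envelope calculation shows why larger M/N tiles shift the kernel toward compute-bound (a sketch under the simplifying assumption that only the fp4 operand loads count, ignoring scales and output traffic):

```python
# Rough arithmetic intensity of an MxNxK matmul tile with fp4 operands.
# Simplifying assumption: count only LHS/RHS operand loads, not scales
# or output stores.
def tile_flops_per_byte(m, n, k):
    flops = 2 * m * n * k               # each MAC counts as 2 FLOPs
    bytes_loaded = (m * k + n * k) / 2  # fp4 = 0.5 byte per element
    return flops / bytes_loaded         # note: k cancels out entirely

old = tile_flops_per_byte(64, 128, 512)   # previous tile sizes, ~170.7 FLOPs/byte
new = tile_flops_per_byte(128, 256, 256)  # larger M/N, smaller K, ~341.3 FLOPs/byte
```

Doubling M and N doubles the FLOPs per byte of global traffic, and K drops out of the ratio, which is consistent with the latency improvements observed after retuning.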
This is awesome, @Yu-Zhewen !
Do the numbers for dt include relayout overheads?
> Do the numbers for `dt` include relayout overheads?
Yes, it's the e2e result using `iree-benchmark-module` (with 5 dispatches in total for dt).
cc @krzysz00 @Muzammiluddin-Syed-ECE who used to work on scaled matmul things.
(Ping noted, someone grab a chunk of my time tomorrow or Thursday to go over this?)
(Also, is there a branch?)
> Also, is there a branch?
https://github.com/iree-org/iree/compare/main...Yu-Zhewen:iree:dt_fp4_ukernel
And the ukernel used is
https://github.com/Yu-Zhewen/iree/blob/59161ca901a1e4aa364f5df4ee4e7784e7d3eace/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_scaled_matmul_f4E2M1FN_m64_n128_k512.mlir
(Please ignore the comments in the code; they are completely out of date.)
@krzysz00, messaged you on Slack.
Update on Tiling Size Tuning
Previous Configuration: M=64, N=128, K=512
dt_m64_n128_k512 (previously referred to as dt + no uk)
| Operation | Time | Percentage |
|---|---|---|
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 2.26ms | 59.19% |
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 1.22 ms | 31.94% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 234.04 µs | 6.12% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.26 µs | 0.84% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 21.27 µs | 0.56% |
dt_uk_m64_n128_k512 (previously referred to as dt + uk)
| Operation | Time | Percentage |
|---|---|---|
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 2.3ms | 60.55% |
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 1.16 ms | 30.65% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 233.83 µs | 6.16% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.29 µs | 0.85% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 21.45 µs | 0.57% |
Improved Configuration: M=128, N=256, K=256
dt_m128_n256_k256
| Operation | Time | Percentage |
|---|---|---|
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 2.48 ms | 51.47% |
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 2.02 ms | 41.91% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 231.52 µs | 4.80% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.56 µs | 0.68% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 23.44 µs | 0.49% |
dt_uk_m128_n256_k256
| Operation | Time | Percentage |
|---|---|---|
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 2.52 ms | 58.06% |
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 1.5 ms | 34.54% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 231.77 µs | 5.34% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.72 µs | 0.75% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 24.88 µs | 0.57% |
Observations
- Increasing the M and N tile sizes while decreasing K significantly reduces the scaled matmul dispatch latency, both for the uk variant (2.3ms → 1.5ms, -34.8%) and for the no-uk variant (2.26ms → 2.02ms, -10.6%). The improvement comes from being less memory-bound.
- The improved dispatch latency comes at the cost of increased LHS encoding time (from ~1.2ms to ~2.5ms). This overhead could potentially be mitigated through fusion when targeting full model compilation.
- With the previous tile sizes (64x128x512), the uk showed minimal benefit (2.26ms vs 2.3ms). After tuning, the uk variant shows a clear advantage: 1.5ms dispatch latency vs 2.02ms for the no-uk variant, a 25.7% improvement.
- The reduced K tile size makes packing scales more challenging. To address this, I enabled `interleaving_intrinsics_m` for lhs_scale and `interleaving_intrinsics_n` for rhs_scale (concepts introduced in #22626). However, the swizzle pattern I used differs from the original implementation. Example:
  - Original: `[["CrossIntrinsicM", 4], ["CrossThreadM", 16]], [["CrossIntrinsicK", 2], ["CrossThreadK", 4]], permutation = [3, 0, 1, 2]`
  - My version: `permutation = [3, 1, 0, 2]`

  Should I just refresh the definition of `interleaving_intrinsics_m/n` or introduce a new flag here? @bjacob
As a reference, I managed to run the same shape with the assembly kernel using https://github.com/nod-ai/fp4-benchmark
Assembly kernel e2e latency: 3.61ms
If we ignore all the encoding cost for now, our 1.5ms is already better.
@jtuyls @bjacob @Max191 @Muzammiluddin-Syed-ECE
> As a reference, I managed to run the same shape with the assembly kernel using https://github.com/nod-ai/fp4-benchmark
> Assembly kernel e2e latency: 3.61ms
> If we ignore all the encoding cost for now, our 1.5ms is already better.
Cool! 3.61ms seems quite slow though (<1.5 PFLOPs). I thought it was running at more than 5 PFLOPs.
> Cool! 3.61ms seems quite slow though (<1.5 PFLOPs). I thought it was running at more than 5 PFLOPs.
This specific assembly kernel was extracted from https://github.com/nod-ai/iree-model-benchmark/blob/dd2a0a1965c144566ea43ce88e909ee77688283b/llama3/base_ir/405b_fp4_asm.mlir which achieves the 5 PFLOPs at shapes much larger than this, for example (16384, 16384, 16384). I can try to find the assembly kernel optimized for smaller shapes in AITER.
> As a reference, I managed to run the same shape with the assembly kernel using https://github.com/nod-ai/fp4-benchmark
I've neglected to update the README.md file and to make sure the flags have helpful descriptions 😅, so I'm curious which flags you're using to run this.
One flag I'd recommend if you want to measure against some larger shapes is `--default-shapes` (takes ~15 minutes to complete), and if you want single-dispatch performance, `--tracy -g`, so the entire command might be:

```shell
cd <fp4-benchmark>
python benchmark.py --asm --default-shapes --tracy -g
```
> Cool! 3.61ms seems quite slow though (<1.5 PFLOPs). I thought it was running at more than 5 PFLOPs.
This specific assembly kernel was extracted from https://github.com/nod-ai/iree-model-benchmark/blob/dd2a0a1965c144566ea43ce88e909ee77688283b/llama3/base_ir/405b_fp4_asm.mlir, which achieves the 5 PFLOPs at shapes much larger than this, for example (16384, 16384, 16384). I can look into AITER to see how the assembly kernel behaves at these smaller shapes.
Just to clarify, the shape I am targeting here is (10000, 16384, 16384). By `M=64, N=128, K=512` I was actually referring to the tile sizes.
Also, I was using the following commands to generate and benchmark the asm:

```shell
python3 benchmark.py --auto --artifacts --asm --m 10000 --n 16384 --k 16384
iree-benchmark-module --device=hip://0 --device_allocator=caching --hip_use_streams=true --benchmark_repetitions=10 --module=10000_16384_8192_512_asm_gemm.vmfb --function=assembly_matmul --input=@10000_8192_i8.npy --input=@16384_8192_i8.npy --input=@10000_512_i8.npy --input=@16384_512_i8.npy
```

I wasn't setting `--iree-hal-benchmark-dispatch-repeat-count=100` / `--batch_size=100`, just to stay consistent with my other experiments, and that might cause some differences.
> Just to clarify, the shape I am targeting here is (10000, 16384, 16384). By `M=64, N=128, K=512` I was actually referring to the tile sizes.
>
> Also, I was using the following commands to generate and benchmark the asm:
>
> ```shell
> python3 benchmark.py --auto --artifacts --asm --m 10000 --n 16384 --k 16384
> iree-benchmark-module --device=hip://0 --device_allocator=caching --hip_use_streams=true --benchmark_repetitions=10 --module=10000_16384_8192_512_asm_gemm.vmfb --function=assembly_matmul --input=@10000_8192_i8.npy --input=@16384_8192_i8.npy --input=@10000_512_i8.npy --input=@16384_512_i8.npy
> ```
Ah yes, my mistake, that looks good.
I misunderstood the reason we're seeing such low performance from the asm kernel: the 5 PFLOPs is achieved when measuring only the time of the fp4 gemm dispatch.
We can verify this by running:

```shell
python3 benchmark.py --asm --m 10000 --n 16384 --k 16384 --tracy -g
```

The .csv files exported by Tracy show that the asm kernel gemm dispatch takes 1.05 ms, which is a throughput of 5.1 PFLOPs.
edit: So, the 5 PFLOPs throughput is not expected in an e2e performance test like the one we're doing here. cc: @sebvince
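The quoted throughput numbers check out arithmetically (an illustrative calculation using the standard 2·M·N·K FLOP count for the matmul, ignoring scale handling):

```python
# Sanity-check the throughput figures discussed above.
M, N, K = 10000, 16384, 16384
flops = 2 * M * N * K  # ~5.37e12 FLOPs per matmul

# Full e2e latency (3.61 ms) vs. gemm-dispatch-only latency (1.05 ms).
e2e_pflops = flops / 3.61e-3 / 1e15       # ~1.49 PFLOPs (the "<1.5" figure)
dispatch_pflops = flops / 1.05e-3 / 1e15  # ~5.11 PFLOPs (the "5.1" figure)
```

So the gap between "<1.5 PFLOPs" and "5 PFLOPs" is entirely explained by what is being timed: the e2e run vs. the gemm dispatch alone.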
Ah right, thanks! I am getting 1.17ms on my side. I wasn't aware the asm version also has that much overhead (3.61ms minus 1.05 or 1.17ms) outside the actual matmul dispatch. So the proper comparison should really be asm (1.05 or 1.17ms) vs. dt + uk (1.5ms) instead.
That makes more sense. Thanks for the clarification @Yu-Zhewen @Muzammiluddin-Syed-ECE !