
[GPU][DT] Implement mxfp4 MLIR ukernel

Open jtuyls opened this issue 3 months ago • 20 comments

jtuyls avatar Sep 11 '25 07:09 jtuyls

I'd like to get a status update for the mxfp4 MLIR ukernel. What is the current status, and what work remains?

(I'm just checking in on the status. We can discuss more tomorrow, but we still need to keep the issue updated.)

hanhanW avatar Nov 18 '25 01:11 hanhanW

E2E Results: Scaled Matmul (FP4) - MxNxKxKB=10000x16384x512x32

Performance Summary

| Configuration | E2E latency |
| --- | --- |
| no dt | 8.08 ms |
| no dt + uk | Not ready |
| dt + no uk | 4.70 ms |
| dt + uk | 4.73 ms |

Waveform Analysis

dt + no uk: (waveform screenshot)

dt + uk: (waveform screenshot)

Analysis

Why is "no dt + uk" not ready?

The implementation exists (https://github.com/iree-org/iree/compare/main...qedawkins:iree:fp4_pingpong#diff-185cc416a5fbb75606542bb5cd6df129015afb7d593eed4c5c876b47191bdcb6R42), but it performs poorly for the given shape. The original plan was to start from the existing "no dt + uk" implementation (assuming it already performed well) and then adapt it to "dt + uk". That turned out not to be the case, so further tuning of tile sizes and barriers is needed, and so far I have only had time to apply those optimizations to the dt version.

Why does "dt + uk" have the same performance as "dt + no uk"?

The direct codegen seems to pipeline well. It even runs three waves per execution unit, which could provide some scheduling advantage. So far, both implementations are bottlenecked by global memory reads: specifically, the yellow blocks (VMEM instructions) and the empty regions (VMEM waits) in the waveforms.

Next Steps on "dt + uk"

  • Reduce memory indexing cost (the light green blocks interleaved with yellow/orange). This might be achievable by collapsing some dimensions, which previously worked well for the f8 cases (see the sketch after this list).
  • Reduce wait time (the empty regions between activities). It’s unclear why these gaps are so large. I suspect low cache hit rates, since the scaled matmul sees only ~30% L1 hits. I saw the example of adding an {aux = 3} attribute (https://github.com/iree-org/iree/commit/06e6089ceb95cfbeb2e4e7bf33036c0e3eeba47e), but I’m not entirely sure what that attribute actually controls, and in my experiments it made performance worse.
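As a rough illustration of the dimension-collapsing idea (a sketch with hypothetical shapes, not the actual ukernel code): collapsing contiguous outer tile dimensions does not move any data, it only reduces the per-access index arithmetic.

```python
import numpy as np

# Hypothetical tile shape for illustration only; the real operands are the
# data-tiled fp4 buffers.
a = np.arange(4 * 16 * 512, dtype=np.int32).reshape(4, 16, 512)

# Collapsing the two outer dimensions of a contiguous buffer is a
# metadata-only change: each access then needs one linearized index instead
# of two, removing one multiply-add of address arithmetic per load.
b = a.reshape(4 * 16, 512)

i, j, k = 3, 7, 42
assert a[i, j, k] == b[i * 16 + j, k]
```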

Yu-Zhewen avatar Nov 18 '25 12:11 Yu-Zhewen

Awesome progress @Yu-Zhewen! FYI @sebvince.

bjacob avatar Nov 18 '25 14:11 bjacob

Awesome! We should catch up to discuss this :). Happy to look at an ATT trace if you have one.

For aux = 3 -> it sets sc0 and sc1, meaning we bypass L1/L2 for global-to-LDS loads. Are you doing MFMAs only once every 2 iterations? I don't see any dark green on both the left and right sides of the screenshot.

sebvince avatar Nov 18 '25 15:11 sebvince

For aux = 3 -> it sets sc0 and sc1, meaning we bypass L1/L2 for global-to-LDS loads.

I see, thanks for sharing the information!

Are you doing MFMAs only once every 2 iterations?

Yes, since I can't afford a larger tile size here. Even if I split the MFMA into 2 stages:

(LDS Stage 0, MFMA Stage 0) -> (LDS Stage 1, MFMA Stage 1)
(Global Mem Stage 0) -> (Global Mem Stage 1)

I don't think it would make a difference, since global memory is the bottleneck anyway.

This is quite different from the fp8 case on gfx942, where LDS is the bottleneck instead and we do:

(All Global mem, LDS Stage 0) -> (MFMA Stage 0)
(Last iteration MFMA Stage 1) -> (LDS Stage 1)

Yu-Zhewen avatar Nov 18 '25 15:11 Yu-Zhewen

After discussing with @sebvince, we might need to relax intrinsicsK

https://github.com/iree-org/iree/blob/4bce5a2f3d21fa4c5594f1f51aa914b9aa4f80eb/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp#L208

and instead increase the M and N tile sizes to shift the kernel toward being more compute-bound rather than the current memory-bound behavior.
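A back-of-the-envelope roofline sketch of why larger M/N tiles should help (my assumptions: only the fp4 LHS/RHS tile loads are counted, scales and output writes are ignored, 0.5 byte per fp4 element):

```python
# Rough tile-level arithmetic intensity: FLOPs per byte of operand traffic.
def arithmetic_intensity(m, n, k, bytes_per_elem=0.5):
    flops = 2 * m * n * k                           # one MAC = 2 FLOPs
    bytes_moved = (m * k + n * k) * bytes_per_elem  # LHS + RHS tile loads only
    return flops / bytes_moved

print(arithmetic_intensity(64, 128, 512))   # ~171 FLOP/byte (current tiles)
print(arithmetic_intensity(128, 256, 512))  # ~341 FLOP/byte (doubled M and N)
```

Under these assumptions the intensity reduces to 4*M*N/(M+N) FLOP/byte, independent of the K tile size, which is why growing M and N (even while shrinking K) moves the kernel toward being compute-bound.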

Yu-Zhewen avatar Nov 18 '25 17:11 Yu-Zhewen

This is awesome, @Yu-Zhewen !

Do the numbers for dt include relayout overheads?

hanhanW avatar Nov 18 '25 23:11 hanhanW

Do the numbers for dt include relayout overheads?

Yes, it's the e2e result using iree-benchmark-module (with 5 dispatches in total for dt)

Yu-Zhewen avatar Nov 18 '25 23:11 Yu-Zhewen

cc @krzysz00 @Muzammiluddin-Syed-ECE who used to work on scaled matmul things.

hanhanW avatar Nov 18 '25 23:11 hanhanW

(Ping noted, someone grab a chunk of my time tomorrow or Thursday to go over this?)

(Also, is there a branch?)

krzysz00 avatar Nov 19 '25 01:11 krzysz00

Also, is there a branch?

https://github.com/iree-org/iree/compare/main...Yu-Zhewen:iree:dt_fp4_ukernel

And the ukernel used is

https://github.com/Yu-Zhewen/iree/blob/59161ca901a1e4aa364f5df4ee4e7784e7d3eace/compiler/plugins/target/ROCM/builtins/mlir_ukernel/iree_uk_amdgpu_dt_scaled_matmul_f4E2M1FN_m64_n128_k512.mlir

(Please ignore the comments in the code, as they are completely out-of-date)

@krzysz00, messaged you on slack

Yu-Zhewen avatar Nov 19 '25 09:11 Yu-Zhewen

Update on Tiling Size Tuning

Previous Configuration: M=64, N=128, K=512

dt_m64_n128_k512 (previously referred to as dt + no uk)

| Operation | Time | Percentage |
| --- | --- | --- |
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 2.26 ms | 59.19% |
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 1.22 ms | 31.94% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 234.04 µs | 6.12% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.26 µs | 0.84% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 21.27 µs | 0.56% |

dt_uk_m64_n128_k512 (previously referred to as dt + uk)

| Operation | Time | Percentage |
| --- | --- | --- |
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 2.3 ms | 60.55% |
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 1.16 ms | 30.65% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 233.83 µs | 6.16% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.29 µs | 0.85% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 21.45 µs | 0.57% |

Improved Configuration: M=128, N=256, K=256

dt_m128_n256_k256

| Operation | Time | Percentage |
| --- | --- | --- |
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 2.48 ms | 51.47% |
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 2.02 ms | 41.91% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 231.52 µs | 4.80% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.56 µs | 0.68% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 23.44 µs | 0.49% |

dt_uk_m128_n256_k256

| Operation | Time | Percentage |
| --- | --- | --- |
| encoding_0_encode_10000x512x32xf4E2M1FN_to_10000x512x32xf4E2M1FN | 2.52 ms | 58.06% |
| scaled_contraction_f4_f4_f8_f8_f32_from_i8_dispatch | 1.5 ms | 34.54% |
| encoding_1_encode_16384x512x32xf4E2M1FN_to_16384x512x32xf4E2M1FN | 231.77 µs | 5.34% |
| encoding_3_encode_16384x512xf8E8M0FNU_to_16384x512xf8E8M0FNU | 32.72 µs | 0.75% |
| encoding_2_encode_10000x512xf8E8M0FNU_to_10000x512xf8E8M0FNU | 24.88 µs | 0.57% |

Observations

  • Increasing the M and N tile sizes while decreasing K significantly reduces the scaled matmul dispatch latency for both the uk variant (2.3 ms → 1.5 ms, -34.8%) and the no-uk variant (2.26 ms → 2.02 ms, -10.6%). The improvement comes from the kernel being less memory-bound.

  • The improved dispatch latency comes at the cost of increased LHS encoding time (from ~1.2ms to ~2.5ms). This overhead could potentially be mitigated through fusion when targeting full model compilation.

  • With the previous tile sizes (64x128x512), the uk showed minimal benefit (2.26 ms vs 2.3 ms). After tuning, the uk variant demonstrates a clear advantage, with a dispatch latency of 1.5 ms compared to 2.02 ms for the no-uk variant, a 25.7% improvement.

  • The reduced K tile size makes packing scales more challenging. To address this, I enabled interleaving_intrinsics_m for lhs_scale and interleaving_intrinsics_n for rhs_scale (concepts introduced in #22626). However, the swizzle pattern I used differs from the original implementation.

    Example:

    • Original: [["CrossIntrinsicM", 4], ["CrossThreadM", 16]], [["CrossIntrinsicK", 2], ["CrossThreadK", 4]], permutation = [3, 0, 1, 2]
    • My version: permutation = [3, 1, 0, 2]

    Should I just update the definition of interleaving_intrinsics_m/n, or introduce a new flag here (see the ordering sketch below)? @bjacob
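For reference, here is a numpy analogy of how the two permutations differ, treating the expanded swizzle dims from the example above as plain array axes (an illustration of the ordering only, not the actual encoding implementation):

```python
import numpy as np

# Expanded dims from the example, modeled as array axes:
# (CrossIntrinsicM=4, CrossThreadM=16, CrossIntrinsicK=2, CrossThreadK=4).
dims = ["CrossIntrinsicM", "CrossThreadM", "CrossIntrinsicK", "CrossThreadK"]
tile = np.arange(4 * 16 * 2 * 4).reshape(4, 16, 2, 4)

original = [3, 0, 1, 2]  # permutation in the existing interleaving_intrinsics_m
mine = [3, 1, 0, 2]      # permutation used in this branch

print([dims[i] for i in original])
print([dims[i] for i in mine])

# The two layouts place elements in a different linear order, so scale buffers
# packed with one definition cannot be read back with the other.
assert not np.array_equal(tile.transpose(original).ravel(),
                          tile.transpose(mine).ravel())
```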

Yu-Zhewen avatar Dec 02 '25 15:12 Yu-Zhewen

As a reference, I managed to run the same shape with the assembly kernel using https://github.com/nod-ai/fp4-benchmark

Assembly kernel e2e latency: 3.61ms

If we ignore all the encoding cost for now, our 1.5ms is already better.

@jtuyls @bjacob @Max191 @Muzammiluddin-Syed-ECE

Yu-Zhewen avatar Dec 02 '25 17:12 Yu-Zhewen

As a reference, I managed to run the same shape with the assembly kernel using https://github.com/nod-ai/fp4-benchmark

Assembly kernel e2e latency: 3.61ms

If we ignore all the encoding cost for now, our 1.5ms is already better.

@jtuyls @bjacob @Max191 @Muzammiluddin-Syed-ECE

Cool! 3.61 ms seems quite slow though (<1.5 PFLOPs). I thought it was running at more than 5 PFLOPs.

sebvince avatar Dec 02 '25 17:12 sebvince

Cool! 3.61 ms seems quite slow though (<1.5 PFLOPs). I thought it was running at more than 5 PFLOPs.

This specific assembly kernel was extracted from https://github.com/nod-ai/iree-model-benchmark/blob/dd2a0a1965c144566ea43ce88e909ee77688283b/llama3/base_ir/405b_fp4_asm.mlir which achieves the 5 PFLOPs at shapes much larger than this, for example (16384, 16384, 16384). I can try to find the assembly kernel optimized for smaller shapes in AITER.

Muzammiluddin-Syed-ECE avatar Dec 02 '25 18:12 Muzammiluddin-Syed-ECE

As a reference, I managed to run the same shape with the assembly kernel using https://github.com/nod-ai/fp4-benchmark

I've neglected to update the README.md file and make sure the flags have helpful descriptions 😅, so I'm curious what flags you're using to run this.

One flag I'd recommend if you want to measure against some larger shapes is --default-shapes (takes ~15 minutes to complete), and for per-dispatch performance, --tracy -g, so the entire command might be:

cd <fp4-benchmark>
python benchmark.py --asm --default-shapes --tracy -g

Muzammiluddin-Syed-ECE avatar Dec 02 '25 19:12 Muzammiluddin-Syed-ECE

Cool! 3.61 ms seems quite slow though (<1.5 PFLOPs). I thought it was running at more than 5 PFLOPs.

This specific assembly kernel was extracted from https://github.com/nod-ai/iree-model-benchmark/blob/dd2a0a1965c144566ea43ce88e909ee77688283b/llama3/base_ir/405b_fp4_asm.mlir which achieves the 5 PFLOPs at shapes much larger than this, for example (16384, 16384, 16384). I can try to see what the assembly kernel does at these smaller shapes by looking into AITER.

Just to clarify, the shape I am targeting here is (10000, 16384, 16384). With M=64, N=128, K=512 I was actually referring to the tile sizes.

Also I was using the following commands to generate and benchmark the asm:

python3 benchmark.py --auto --artifacts --asm --m 10000 --n 16384 --k 16384

iree-benchmark-module --device=hip://0 --device_allocator=caching --hip_use_streams=true --benchmark_repetitions=10 --module=10000_16384_8192_512_asm_gemm.vmfb --function=assembly_matmul --input=@10000_8192_i8.npy --input=@16384_8192_i8.npy --input=@10000_512_i8.npy --input=@16384_512_i8.npy 

I wasn't setting --iree-hal-benchmark-dispatch-repeat-count=100 or --batch_size=100, just to stay consistent with my other experiments, and that might cause some differences.

Yu-Zhewen avatar Dec 02 '25 19:12 Yu-Zhewen

Just to clarify, the shape I am targeting here is (10000, 16384, 16384). With M=64, N=128, K=512 I was actually referring to the tile sizes.

Also I was using the following commands to generate and benchmark the asm:

python3 benchmark.py --auto --artifacts --asm --m 10000 --n 16384 --k 16384

iree-benchmark-module --device=hip://0 --device_allocator=caching --hip_use_streams=true --benchmark_repetitions=10 --module=10000_16384_8192_512_asm_gemm.vmfb --function=assembly_matmul --input=@10000_8192_i8.npy --input=@16384_8192_i8.npy --input=@10000_512_i8.npy --input=@16384_512_i8.npy 

Ah yes, my mistake, that looks good.

I misunderstood the reason we're seeing such low performance from the asm kernel. The 5 PFLOPs figure is achieved when measuring only the time of the fp4 gemm dispatch.

We can verify this by running

python3 benchmark.py --asm --m 10000 --n 16384 --k 16384 --tracy -g 

The .csv files exported by Tracy show that the asm kernel's gemm dispatch takes 1.05 ms, which is a throughput of 5.1 PFLOPs.
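For reference, both throughput figures follow directly from the matmul FLOP count for this shape (a quick sanity check, counting 2*M*N*K FLOPs):

```python
# Shape benchmarked in this thread.
M, N, K = 10000, 16384, 16384
flops = 2 * M * N * K  # ~5.37e12 FLOPs per matmul

print(flops / 3.61e-3 / 1e15)  # ~1.49 PFLOP/s for the 3.61 ms e2e time
print(flops / 1.05e-3 / 1e15)  # ~5.11 PFLOP/s for the 1.05 ms dispatch time
```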

edit: So the 5 PFLOPs throughput is not expected in an e2e performance test like the one we're doing here. cc: @sebvince

Muzammiluddin-Syed-ECE avatar Dec 02 '25 19:12 Muzammiluddin-Syed-ECE

Ah right, thanks! I am getting 1.17 ms on my side. I wasn't aware the asm version also has that much overhead (3.61 ms minus 1.05 or 1.17 ms) outside the actual matmul dispatch. So the proper comparison should really be asm (1.05 or 1.17 ms) vs. dt + uk (1.5 ms) instead.

Yu-Zhewen avatar Dec 02 '25 19:12 Yu-Zhewen

That makes more sense. Thanks for the clarification @Yu-Zhewen @Muzammiluddin-Syed-ECE !

sebvince avatar Dec 03 '25 09:12 sebvince