
Work breakdown for MXFP enablement

krzysz00 opened this issue 7 months ago • 1 comment

This ticket is meant to track the work needed to get mxfp4 enabled in IREE, including the current state of the unknown unknowns. Please feel free to edit it or comment on it if you see anything missing.

  • [ ] Model-level work. Overall, there's uncertainty about the motivating use case, the desired benchmark numbers, etc.
    • [ ] Identify or create a model or other input that will be using mxfp4
    • [ ] Determine how this will translate to an IREE ingress format
    • [ ] Create any high-level tensor ops needed to represent scaled matrix multiply. This could take the form of a linalg.generic, which likely depends on creating arith.scaling_extf and arith.scaling_truncf (see below).
    • [ ] If any model-running / quantization infrastructure needs to be updated to work with scale tensors, that goes here.
  • [ ] Dispatch creation work. I (@krzysz00) am unclear how much of this section is trivial.
    • [ ] Ensure scaled matrix multiplies are correctly formed into dispatches
    • [ ] Ensure that scale broadcasts don't materialize into copies
    • [ ] Ensure we have a high-level representation we can consistently pattern-match on
    • [x] No need for HAL changes; dispatch tensor loads and stores already have implicit bitcast semantics.
  • [ ] High-level codegen work.
    • [ ] Recognize scaled matrix multiply dispatches and rewrite them to something.
    • [ ] Define something that represents scaled matrix multiplies. There are several competing proposals here:
      a) Extend iree_gpu.multi_mma with optional scale parameters and start plumbing that through (as far as we know, this is the Triton approach).
      b) Move iree_gpu.multi_mma into a variadic iree_codegen.[names_hard] op that represents the general concept of inner-tiled operations, and add a new interface that is a parent of MultiMmaInterface. Then fail anything that isn't a matmul in all the rewrites and slowly relax that restriction. This is probably the best long-term solution, but also the slowest.
      c) Something with Wave?
    • [ ] Ensure whatever we defined lowers to amdgpu.scaling_mfma
    • [ ] Ensure that scaling quantization/dequantization get lowered to the relevant intrinsics / are vectorized appropriately. (May not need doing depending on what the models look like)
  • [ ] Low-level codegen work. This is all IREE-independent work needed to enable calling the instructions we think we'll need to generate in these models.
    • [x] Add MFMA intrinsic wrapper ( https://github.com/llvm/llvm-project/pull/137498 )
    • [ ] Conversions (needed for quantizing output)
      • [x] Add ROCDL wrappers for the fp4 (and fp6) conversion instructions ( https://github.com/llvm/llvm-projects/pull/140801 )
      • [ ] Add AMDGPU wrappers for these conversion operations ( https://github.com/iree-org/iree/issues/20890 ) (partially in progress).
    • [ ] Add arith.scaling_extf and arith.scaling_truncf ( https://github.com/llvm/llvm-project/issues/138207 ). Umamg from rocMLIR has started work on this; we may need to pitch in.
    • [ ] Add ArithToAMDGPU logic to translate arith.scaling_extf and arith.scaling_truncf into those intrinsic-wrapping ops. Design-wise this isn't too difficult: it mainly requires unrolling the scaling ops until the "scale" argument is a broadcast or splat vector, after which they can be lowered much like the existing round-to-fp8 ops. But it still needs doing.
    • [ ] Ensure that sub-byte load/store support (upstream's EmulateNarrowTypes, I think) works with fp4. fp6 is out of scope here, but we'll at least want to think about it (@lialan if you have the bandwidth here)
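For anyone picking up the scaled-matmul items above, the reference semantics are simple: each 32-element K-block of an MX-quantized operand carries one shared scale, and the product of the two operands' block scales multiplies that block's partial dot product. A minimal Python sketch (the function and names here are illustrative, not IREE/MLIR APIs; inputs are assumed already decoded to floats):

```python
BLOCK = 32  # MX block size: one shared scale per 32 elements

def scaled_dot(a, a_scales, b, b_scales, block=BLOCK):
    """Reference dot product of two rows with per-block shared scales.

    a, b:              already-decoded element values (same length,
                       a multiple of `block`)
    a_scales, b_scales: one scale per block of each operand
    """
    assert len(a) == len(b) and len(a) % block == 0
    acc = 0.0
    for blk in range(len(a) // block):
        s = a_scales[blk] * b_scales[blk]  # combined block scale
        lo = blk * block
        partial = sum(a[i] * b[i] for i in range(lo, lo + block))
        acc += s * partial  # scale applied once per block, not per element
    return acc
```

This per-block factoring is what lets the scales ride along as small side tensors rather than being broadcast into full-size copies.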

krzysz00 · May 28 '25 21:05

Suggested by the team:

  • Add unpacking of fp4 values to fp32 for debugging use with iree-run-module
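For reference, that unpacking is a small table lookup: fp4 here is E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit), packed two values per byte. A sketch in Python, assuming low-nibble-first packing (the helper names are hypothetical, not iree-run-module functionality):

```python
# E2M1 magnitudes for codes 0..7; bit 3 of the nibble is the sign.
_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def fp4_to_f32(nibble: int) -> float:
    """Decode one 4-bit E2M1 value held in the low nibble."""
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * _E2M1[nibble & 0x7]

def unpack_fp4_bytes(data: bytes) -> list:
    """Unpack two fp4 values per byte, low nibble first (an assumption)."""
    out = []
    for byte in data:
        out.append(fp4_to_f32(byte & 0xF))         # low nibble
        out.append(fp4_to_f32((byte >> 4) & 0xF))  # high nibble
    return out
```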

Muzammiluddin-Syed-ECE · May 29 '25 15:05

Closing: I think we've checked all the boxes here as far as MXFP 1.0 is concerned.

krzysz00 · Aug 13 '25 21:08