Work breakdown for MXFP enablement
This ticket is meant to track the work needed to get mxfp4 enabled in IREE, including the current state of the unknown unknowns. Please feel free to edit it or comment on it if you see anything missing.
- [ ] Model-level work. Overall, there's uncertainty about the motivating use case, the desired benchmark numbers, etc.
  - [ ] Identify or create a model or other input that will be using mxfp4
  - [ ] Determine how this will translate to an IREE ingress format
  - [ ] Create any high-level tensor ops needed to represent scaled matrix multiply. This could take the form of a `linalg.generic`, which likely depends on creating `arith.scaling_extf` and `arith.scaling_truncf` below (see the sketch after this list).
  - [ ] If any model-running / quantization infrastructure needs to be updated to work with scale tensors, that goes here.
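
As a rough illustration of what such a tensor-level representation might look like, here is a minimal sketch of a scaled matmul written as a dequantizing `linalg.generic` feeding an ordinary `linalg.matmul`. This is not a settled design: the shapes, the 32-element block size, the dequantize-then-matmul structure, and the exact `arith.scaling_extf` assembly syntax are all illustrative assumptions to be checked against the upstream work.

```mlir
// Hypothetical tensor-level sketch: dequantize a block-scaled fp4 operand
// (one f8E8M0FNU scale per 32 values) and feed the result into a plain matmul.
#val   = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#scale = affine_map<(d0, d1, d2) -> (d0, d1)>

func.func @scaled_matmul_sketch(
    %a: tensor<128x8x32xf4E2M1FN>,       // A, grouped into 32-element blocks
    %a_scales: tensor<128x8xf8E8M0FNU>,  // one shared scale per block of A
    %b: tensor<256x64xf32>,
    %acc: tensor<128x64xf32>) -> tensor<128x64xf32> {
  %empty = tensor.empty() : tensor<128x8x32xf32>
  // Elementwise dequantization: every value in a block is scaled by the
  // block's shared scale via the (proposed) arith.scaling_extf op.
  %a_deq = linalg.generic
      {indexing_maps = [#val, #scale, #val],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%a, %a_scales : tensor<128x8x32xf4E2M1FN>, tensor<128x8xf8E8M0FNU>)
      outs(%empty : tensor<128x8x32xf32>) {
  ^bb0(%v: f4E2M1FN, %s: f8E8M0FNU, %out: f32):
    %d = arith.scaling_extf %v, %s : f4E2M1FN, f8E8M0FNU to f32
    linalg.yield %d : f32
  } -> tensor<128x8x32xf32>
  // Collapse the block dimension back into K and do a regular matmul.
  %a_flat = tensor.collapse_shape %a_deq [[0], [1, 2]]
      : tensor<128x8x32xf32> into tensor<128x256xf32>
  %res = linalg.matmul
      ins(%a_flat, %b : tensor<128x256xf32>, tensor<256x64xf32>)
      outs(%acc : tensor<128x64xf32>) -> tensor<128x64xf32>
  return %res : tensor<128x64xf32>
}
```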
- [ ] Dispatch creation work. I (@krzysz00) am rather unclear on how much of this section is trivial.
  - [ ] Ensure scaled matrix multiplies are correctly formed into dispatches
  - [ ] Ensure that scale broadcasts don't materialize into copies
  - [ ] Ensure we have a high-level representation we can consistently pattern-match on
  - [x] No need for HAL changes; there are implicit bitcast semantics on dispatch tensor load etc.
- [ ] High-level codegen work.
  - [ ] Recognize scaled matrix multiply dispatches and rewrite them to the representation chosen below.
  - [ ] Define something that represents scaled matrix multiplies. There are various competing proposals here:
    a) Extend `iree_gpu.multi_mma` with optional scale parameters and start plumbing that through (this is, as far as we know, the Triton approach)
    b) Move `iree_gpu.multi_mma` into a variadic `iree_codegen.[names_hard]` to represent the general concept of inner-tiled operations, and make a new interface that's a parent of the `MultiMmaInterface`. Then, fail anything that isn't a matmul in all the rewrites and slowly relax that. This is probably the best long-term solution, but is the slowest.
    c) Something with Wave???
  - [ ] Ensure whatever we define lowers to `amdgpu.scaling_mfma`
  - [ ] Ensure that scaling quantization/dequantization get lowered to the relevant intrinsics / are vectorized appropriately. (May not need doing depending on what the models look like)
- [ ] Low-level codegen work. This is all IREE-independent work needed to enable calling the instructions we think we'll need to generate in these models.
  - [x] Add MFMA intrinsic wrapper ( https://github.com/llvm/llvm-project/pull/137498 )
  - [ ] Conversions (needed for quantizing output)
    - [x] Add ROCDL wrappers for fp4 (and fp6) conversion operations ( https://github.com/llvm/llvm-project/pull/140801 )
    - [ ] Add AMDGPU wrappers for these conversion operations ( https://github.com/iree-org/iree/issues/20890 ) (partially in progress).
    - [ ] Add `arith.scaling_extf` and `arith.scaling_truncf` ( https://github.com/llvm/llvm-project/issues/138207 ). Umang from rocMLIR has started work on this; we might need to pitch in.
    - [ ] Add `ArithToAMDGPU` logic to handle translating `arith.scaling_extf` and `arith.scaling_truncf` into those intrinsic-wrapping ops. Design-wise, this isn't too difficult (it mainly requires unrolling these scaling ops until the "scale" argument is a broadcast or splat vector, and then you can lower them somewhat like how the existing round-to-fp8 ops are handled), but it still needs doing (see the sketch after this list).
  - [ ] Ensure that sub-byte load/store support (upstream's EmulateNarrowTypes, I think) works with fp4. fp6 is out of scope here, but we'll at least want to think about it (@lialan if you have the bandwidth here)
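
For a sense of what that `ArithToAMDGPU` unrolling would aim for, here is a rough sketch. It is illustrative only: the block size, the vector shapes, and the exact `arith.scaling_extf` assembly syntax are assumptions to be checked against the upstream patches; only the first block of the unrolled form is shown.

```mlir
func.func @unroll_scaling_extf_sketch(
    %vals: vector<8x32xf4E2M1FN>,
    %scales: vector<8x32xf8E8M0FNU>) -> vector<32xf32> {
  // Before lowering, the dequantization would be a single op whose scale
  // operand logically holds 8 blocks of 32 lanes, each block sharing one value:
  //   %deq = arith.scaling_extf %vals, %scales
  //       : vector<8x32xf4E2M1FN>, vector<8x32xf8E8M0FNU> to vector<8x32xf32>
  // The pattern unrolls along the block dimension until each piece sees a
  // splat scale, which is the shape the per-block conversion intrinsics want.
  // Block 0 of that unrolled form looks like:
  %v0  = vector.extract %vals[0] : vector<32xf4E2M1FN> from vector<8x32xf4E2M1FN>
  %s0  = vector.extract %scales[0, 0] : f8E8M0FNU from vector<8x32xf8E8M0FNU>
  %s0v = vector.broadcast %s0 : f8E8M0FNU to vector<32xf8E8M0FNU>
  %d0  = arith.scaling_extf %v0, %s0v
      : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xf32>
  // ...blocks 1..7 are handled the same way and reassembled with vector.insert.
  return %d0 : vector<32xf32>
}
```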
Suggested by the team:
- Add unpacking of fp4 values to fp32 for debugging use with `iree-run-module`
Closing because I think we have checked all the boxes here as far as MXFP 1.0 is concerned.