MXFP Inference and Performance Tracking
Summary
This issue tracks performance and E2E integration of MXFP formats (MXFP8, MXFP4, NVFP4) on B200 and other devices.
Status Overview
| Component | Status | Notes |
|---|---|---|
| Dense Model Support | ✅ Done | Dense models are working E2E |
| MOE Support | 🟡 Not Started | Need to add MXFP8 support to scaled_grouped_gemm |
| VLLM Integration | 🟡 In Progress | Works E2E, but performance is inconsistent |
| VLLM MXFP8 Performance | 🟡 Suboptimal | Currently ~11% slower than BF16 baseline |
| VLLM MXFP4 Performance | 🟡 Suboptimal | Only on par with the BF16 baseline |
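For context, the dense E2E path amounts to swapping `nn.Linear` weights for MX tensors via torchao's `quantize_` entry point and compiling the model. A minimal sketch is below; the config class name (`MXFPInferenceConfig`) and its import path are assumptions about the current mx_formats prototype, not a confirmed API.

```python
# Minimal sketch of the dense E2E flow. NOTE: the MX config class name and its
# import path are assumptions; check torchao.prototype.mx_formats for the
# actual inference entry point.
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXFPInferenceConfig  # hypothetical import

# Stand-in for one transformer MLP block.
model = nn.Sequential(
    nn.Linear(4096, 14336, bias=False),
    nn.SiLU(),
    nn.Linear(14336, 4096, bias=False),
).to(torch.bfloat16).cuda()

# Swap Linear weights for MXFP8 tensors (1x32 blocks with E8M0 scales).
quantize_(model, MXFPInferenceConfig())  # assumes the default config selects MXFP8

# Compile so Inductor can handle the dynamic activation cast + scale swizzle.
model = torch.compile(model)
with torch.inference_mode():
    out = model(torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16))
```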
GEMM Kernels
| Format | Kernel Source | Performance Status | Notes |
|---|---|---|---|
| MXFP8 | cuBLAS / scaled_mm | ✅ Optimal | Performs as advertised |
| MXFP4 | CUTLASS / AO | 🟡 Suboptimal | Slower than expected on compute-bound shapes; needs tuning (see the timing sketch below) |
| NVFP4 | cuBLAS / scaled_mm | 🔴 Not benchmarked | Needs to be wired up in TorchAO |
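To quantify how far the MXFP4 CUTLASS kernel sits from the BF16 roofline on the shapes we care about, a small per-shape timing harness is enough. The sketch below uses `torch.utils.benchmark`; the shapes are illustrative, and `quantized_gemm` is a placeholder for whichever scaled kernel is under test, not a real API.

```python
# Per-shape GEMM timing loop. The BF16 matmul is the baseline; swap in the
# MXFP8/MXFP4 kernel under test where `quantized_gemm` is referenced.
import torch
from torch.utils import benchmark

def time_fn(fn, label):
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn}, label=label)
    return t.blocked_autorange(min_run_time=1.0)

# Illustrative decode / prefill shapes for a 70B-style feed-forward.
shapes = [(1, 8192, 28672), (4096, 8192, 28672), (4096, 28672, 8192)]

for M, K, N in shapes:
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t()
    m = time_fn(lambda: a @ b, f"bf16 {M}x{K}x{N}")
    tflops = 2 * M * K * N / m.median / 1e12
    print(f"bf16 {M}x{K}x{N}: {m.median * 1e6:.1f} us, {tflops:.1f} TFLOP/s")
    # m_q = time_fn(lambda: quantized_gemm(a_q, b_q, a_scale, b_scale), ...)  # placeholder
```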
Casting Kernels
| Format | Kernel Source | Performance Status | Notes |
|---|---|---|---|
| MXFP | AO/Triton | 🟡 Pretty good | Dim-0 cast and scale swizzle are each near-optimal, but they can be fused (reference sketch below) |
| MXFP | Inductor | 🟡 Suboptimal | Falls back to eager; needs fixes |
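For reference, the dim-0 cast in the table above is a blocked amax → power-of-two scale → e4m3 cast per the OCP MX spec (1x32 blocks, E8M0 shared scale). A pure-PyTorch sketch of the semantics is below, ignoring E8M0 range clamping and the scale swizzle; it describes what the fused kernel needs to reproduce, not the Triton implementation itself.

```python
# Pure-PyTorch reference for the MX dim-0 cast: 32 consecutive elements along
# dim 0 share one power-of-two scale, and elements are stored as e4m3.
import torch

BLOCK = 32       # MX block size per the OCP spec
E4M3_EMAX = 8    # floor(log2(448)), the largest e4m3 exponent

def mx_cast_dim0(x: torch.Tensor):
    rows, cols = x.shape
    assert rows % BLOCK == 0
    blocks = x.reshape(rows // BLOCK, BLOCK, cols)           # group dim 0 into 32-element blocks
    amax = blocks.abs().amax(dim=1).float()                  # (rows // 32, cols)
    amax = amax.clamp(min=torch.finfo(torch.float32).tiny)   # avoid log2(0) on all-zero blocks
    scale = torch.exp2(torch.floor(torch.log2(amax)) - E4M3_EMAX)  # power-of-two shared scale
    data = (blocks.float() / scale.unsqueeze(1)).to(torch.float8_e4m3fn)
    # In the real kernel the scales are also rounded to E8M0 and swizzled for the GEMM.
    return data.reshape(rows, cols), scale

x = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
qdata, scales = mx_cast_dim0(x)
```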
Tasks
1. Kernel Optimization
- [x] Implement custom swizzle kernel for feed-forward networks. See https://github.com/pytorch/ao/pull/2168; this currently works around Inductor.
- [ ] Investigate why Inductor falls back to eager for the swizzle kernel. Started, but a fix PR still needs to land: https://github.com/pytorch/pytorch/issues/153194, https://github.com/pytorch/pytorch/pull/154006
- [ ] Optimize the MXFP4 kernel in AO, which isn't performing as expected. It is a very vanilla CUTLASS template; we need to identify the shapes we care about and likely instantiate a few more templates.
- [ ] Implement scale caching for static weight tensors (post tensor parallelism). Once the quantized model is loaded and the weights are sharded for TP, the MX scales can be pre-swizzled. We should build a mechanism for caching these, which should show some speedup (see the sketch after this list).
- [ ] Develop a single kernel for MX cast + swizzled scale generation. Triton versions exist (https://github.com/pytorch/ao/issues/2217); the ideal end state is for Inductor to produce this for us.
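One possible shape for the scale-caching item above, reusing `mx_cast_dim0` from the casting sketch earlier; the class name and `swizzle_scales` are hypothetical stand-ins for the real torchao/Triton pieces.

```python
# Hypothetical sketch of scale caching for static (post-TP) weights. `swizzle_scales`
# stands in for the real blocked-layout swizzle kernel and is a no-op here.
import torch

def swizzle_scales(scales: torch.Tensor) -> torch.Tensor:
    # Placeholder: the real kernel rearranges scales into the layout the GEMM expects.
    return scales.contiguous()

class MXWeightCache:
    """Quantize and swizzle weight scales once at load time, then reuse every step."""

    def __init__(self, weight: torch.Tensor):
        qdata, scales = mx_cast_dim0(weight)       # weights are static after TP sharding
        self.qdata = qdata
        self.scales_swizzled = swizzle_scales(scales)

    def get(self):
        # Hot path: no recomputation or re-swizzling of weight scales per forward.
        return self.qdata, self.scales_swizzled

cache = MXWeightCache(torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16))
w_qdata, w_scales = cache.get()
```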
2. VLLM Integration
- [ ] Add an option for the NVFP4 quant scheme
- [ ] Debug inconsistent behavior in VLLM integration
- [ ] Optimize TTFT (Time To First Token) for MXFP8 format
- [ ] Ensure consistent throughput across different model sizes
- [ ] Profile memory bandwidth utilization for different formats
- [ ] Compare latency patterns across BF16, MXFP8, and MXFP4
Traces:
Decode, 8B model (batch-1 GEMV): BF16: https://fburl.com/rj5cpto2, MXFP8: https://fburl.com/xo985lyg
In this case Inductor produces a fully unrolled GEMV with no tensor cores: https://www.internalfb.com/intern/paste/P1817637421/. I wonder what we need to do to support this. cc @eellison
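A minimal repro of the batch-1 decode case (hidden size illustrative); running it with `TORCH_LOGS="output_code"` dumps the kernel Inductor generates, which should confirm whether we still get the fully unrolled GEMV.

```python
# Batch-1 decode matmul through torch.compile. Run with TORCH_LOGS="output_code"
# to inspect the generated kernel (the trace above shows a fully unrolled GEMV
# with no tensor-core usage).
import torch

w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
x = torch.randn(1, 4096, device="cuda", dtype=torch.bfloat16)  # batch-1 decode activation

@torch.compile(mode="max-autotune")  # check whether autotuning picks a better template
def gemv(x, w):
    return x @ w.t()

out = gemv(x, w)
```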
Performance Data
Performance by Format (Qwen2-7B-Instruct)
| Format | Throughput (req/s) | Total Token Throughput (tok/s) | Output Token Throughput (tok/s) |
|---|---|---|---|
| BF16 | 56.68 | 24053.96 | 11590.20 |
| MXFP8 | 50.52 | 21443.10 | 10332.18 |
| MXFP4 | 56.64 | 24039.96 | 11583.46 |
Performance by Format (Qwen2.5-72B)
| Format | Throughput (req/s) | Total Token Throughput (tok/s) | Output Token Throughput (tok/s) |
|---|---|---|---|
| BF16 | 26.28 | 11154.41 | 5374.66 |
| MXFP4 | 25.96 | 11018.18 | 5309.02 |
MXFP8 Serving Benchmark
============ Serving Benchmark Result ============
Successful requests: 1024
Benchmark duration (s): 13.43
Total input tokens: 225502
Total generated tokens: 185297
Request throughput (req/s): 76.26
Output token throughput (tok/s): 13800.30
Total Token throughput (tok/s): 30594.93
---------------Time to First Token----------------
Mean TTFT (ms): 1119.68
Median TTFT (ms): 1100.86
P99 TTFT (ms): 1721.80
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 29.11
Median TPOT (ms): 27.07
P99 TPOT (ms): 46.91
---------------Inter-token Latency----------------
Mean ITL (ms): 23.38
Median ITL (ms): 17.28
P99 ITL (ms): 49.45
==================================================
Numerics
Some quick LM evals:
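While the LM eval numbers are being collected, a quick per-tensor sanity check is to compare a dequantized MX tensor against its BF16 source via SQNR. A generic sketch, reusing `mx_cast_dim0` from the casting section above as the stand-in quantizer:

```python
# SQNR of a dequantized MXFP8 tensor vs. its BF16 source, using mx_cast_dim0
# from the casting sketch above as the quantizer.
import torch

def sqnr(ref: torch.Tensor, approx: torch.Tensor) -> float:
    err = (ref.float() - approx.float()).pow(2).mean()
    return (10 * torch.log10(ref.float().pow(2).mean() / err)).item()

w = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)
qdata, scales = mx_cast_dim0(w)
# Dequantize: broadcast each per-block scale back over its 32 rows.
dq = qdata.float().reshape(-1, 32, w.shape[1]) * scales.unsqueeze(1)
print(f"MXFP8 weight SQNR: {sqnr(w, dq.reshape_as(w)):.1f} dB")
```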
References
- Llama 70B feed forward with custom kernel: https://fburl.com/125yv8hh
- Llama 70B feed forward without custom kernel: https://fburl.com/a21gwjmc
- BF16 reference: https://fburl.com/2lgn9xkx
- Non-quantized trace: https://fburl.com/sput3bmn
- Quantized trace: https://fburl.com/0pgmyrge
- BF16 70B MLP: https://fburl.com/aeqm5s4v
- MXFP8 70B MLP: https://fburl.com/uxgoju4r
- MXFP4 70B MLP: https://fburl.com/u95f6f39
- Eager vs. Inductor swizzle profile: https://fburl.com/kqhm91ib