MX single node performance tracker
This issue tracks single node performance of MX training and inference: fast gemm, fast fused kernels. Once this issue is complete, we can train on a single node (8 GPUs) at SOTA performance with MXFP8, and do inference (details TBD) with MXFP8 and MXFP4.
training performance summary
As of 2025-03-27
- e2e pretraining speedup vs bf16 + compile on LLaMa 3 8B, 8 B200 GPUs, torchtitan with default settings
  - 🟢 float8 tensorwise: 1.19x
  - 🔴 mxfp8: 1.09x (should be similar to tensorwise's 1.19x once we fix all the issues). Right now scaling/casting to mx is slow
- 🟢 gemm speedup: cuBLAS mxfp8 gemm is 2x to 3x faster than bf16 - done for now
- 🟢 mx casting to dim0 with torch.compile achieves up to 67% of peak mem bw - done for now
- 🔴 mx casting to dim1 is our main performance gap
  - 🔴 torch.compile achieves only up to 9% peak mem bw, tracking here: https://github.com/pytorch/pytorch/issues/149982
  - 🟡 handwritten triton kernel from https://github.com/pytorch/ao/pull/1932 achieves up to 44% peak mem bw, can likely be improved further
- 🔲 mx casting to dim0 + dim1 at the same time is postponed for now until we make the individual dim0 and dim1 kernels better
individual components
system overview (for training)
# There are three gemms in a forward + backward of a Linear layer:
#
# 1. input @ weight_t = output (forward pass)
# 2. grad_output @ weight = grad_input (backward pass)
# 3. input_t @ grad_output = grad_weight (backward pass)
#
# in Python pseudocode, we want the following (for mxfp8):
# forward pass
# inputs are in high precision
x_hp, w_hp = ...
# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)
# backward pass
# inputs are in high precision
x_hp, w_hp, go_hp = ...
# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)
# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous().t(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)
We want:
- the mx gemm to be fast
- the cast from high precision to mx (`to_mx` in pseudocode above) to be fast
- the cast from high precision to mx to be fused to preceding/subsequent ops where possible
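`to_mx` is not defined in this issue. As a rough illustration of what the dim0 and dim1 casts have to compute, the sketch below derives only the per-block power-of-two scale exponents in pure Python. It assumes the OCP MX convention of 32-element blocks with an E8M0 scale and a floor-based exponent; the helper names (`block_scale_exponents`, `transpose`) are hypothetical, and the element cast to fp8 is omitted.

```python
import math

BLOCK_SIZE = 32                  # elements sharing one scale (assumption, per the OCP MX spec)
E4M3_EMAX = 8                    # exponent of the largest float8 e4m3 normal, 448 = 1.75 * 2**8
E8M0_MIN, E8M0_MAX = -127, 127   # unbiased exponents representable by an e8m0 scale


def block_scale_exponents(rows):
    """Per-block scale exponents for a 2D list, with blocks taken along the last dim."""
    result = []
    for row in rows:
        exps = []
        for start in range(0, len(row), BLOCK_SIZE):
            amax = max(abs(v) for v in row[start:start + BLOCK_SIZE])
            if amax == 0.0:
                exps.append(E8M0_MIN)
                continue
            # frexp gives amax = m * 2**e with 0.5 <= m < 1, so floor(log2(amax)) = e - 1
            floor_log2 = math.frexp(amax)[1] - 1
            exps.append(max(E8M0_MIN, min(E8M0_MAX, floor_log2 - E4M3_EMAX)))
        result.append(exps)
    return result


def transpose(rows):
    """Materialize the transpose, mirroring x_hp.t().contiguous() in the pseudocode."""
    return [list(col) for col in zip(*rows)]


# 64x64 input where every element of row r equals r + 1
x_hp = [[float(r + 1)] * 64 for r in range(64)]
scales_dim0 = block_scale_exponents(x_hp)             # blocks run along each row
scales_dim1 = block_scale_exponents(transpose(x_hp))  # blocks run along each column
```

The same data yields different scales depending on the blocking direction, which is why the backward pass needs separate dim1 casts rather than a transpose of the dim0 result.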
gemm kernel
Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPs for bf16, 4.5 petaFLOPs for fp8/fp6 (2x from bf16), 9.0 petaFLOPs for fp4 (4x from bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)
| kernel | wrapper | current TFLOPs | peak TFLOPs | notes |
|---|---|---|---|---|
| mxfp8 cuBLAS | `torch._scaled_mm` | TBD | 4.5 petaFLOPs | landed, https://github.com/pytorch/pytorch/pull/147548 |
| mxfp8 CUTLASS | `torchao.ops.mx_fp8_bf16` | TBD | 4.5 petaFLOPs | landed, https://github.com/pytorch/ao/pull/1637 |
| mxfp4 CUTLASS | `torchao.ops.mx_fp4_bf16` | TBD | 9.0 petaFLOPs | landed, https://github.com/pytorch/ao/pull/1661 |
| nvfp4 cuBLAS | `torch._scaled_mm` | TBD | 9.0 petaFLOPs | in progress, https://github.com/pytorch/pytorch/pull/148792 |
Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the TFLOP column in the table above.
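As a sketch of how a measured gemm time would translate into the TFLOP column, the helper below uses the usual 2 * M * K * N FLOP count for a dense gemm. The function names and the 4 ms timing are hypothetical placeholders, not measurements; the 9.0 petaFLOPs peak is the fp4 figure quoted above.

```python
def gemm_tflops(m, k, n, seconds):
    """Achieved TFLOP/s of an (m x k) @ (k x n) gemm, counting 2 FLOPs per multiply-add."""
    return 2 * m * k * n / seconds / 1e12


def fraction_of_peak(tflops, peak_pflops):
    """Fraction of a peak given in petaFLOPs (1 petaFLOP = 1000 TFLOPs)."""
    return tflops / (peak_pflops * 1000)


# Hypothetical timing for illustration only: a 16384^3 gemm that takes 4 ms
example_tflops = gemm_tflops(16384, 16384, 16384, 4e-3)
example_fraction = fraction_of_peak(example_tflops, peak_pflops=9.0)
print(f"{example_tflops:.0f} TFLOPs, {example_fraction:.1%} of peak")
```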
scaling/casting kernels
Our current plan is to use torch.compile, same as we are doing with float8.
- we should ensure we can generate a single fused kernel for scaling and casting a tensor to mxfp8. Today, torch.compile generates two kernels: https://github.com/pytorch/ao/issues/1769
- once we have a single fused kernel, we should make sure it's bandwidth bound. As of 2025-02-24, the casting to MX code is numerically correct but researchy and has not been optimized for performance. TODO issue.
- the `float8_e8m0fnu` dtype was added to PyTorch in https://github.com/pytorch/pytorch/pull/147466, we need to update `torchao` to use this dtype for scales, and then ensure that PT2 works e2e. TODO issue
- we need to ensure torch.compile is good at generating good fused kernels for the custom scale packing layout required by B200s. https://github.com/pytorch/ao/issues/1773
- we should ensure the cast across dim0 and dim1 is performant: https://github.com/pytorch/ao/issues/1788
- given an MXLinear (fwd + bwd), we should expect at most six scale+cast kernels: two for each of `input`, `weight`, `grad_output`. The kernels for `input` and `grad_output` should be fused with preceding/subsequent ops as appropriate. TODO issue.
e2e training performance
From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).
- we need a roofline of mx scaling/casting to get the shapes which are expected to see speedups, and we should have a benchmark to compare observed to theoretical performance
- [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: https://github.com/pytorch/pytorch/issues/145949
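A minimal sketch of the memory-bound side of such a roofline for a single mxfp8 cast, assuming a bf16 input read once, fp8 output written once, one e8m0 scale byte per 32-element block, and the 8 TB/s bandwidth quoted above. The function names are hypothetical, and fusion with neighboring ops is ignored.

```python
B200_MEM_BW_BYTES_PER_S = 8e12   # 8 TB/s, the figure quoted above
MX_BLOCK_SIZE = 32               # elements per scale (assumption)


def mxfp8_cast_bytes(m, k):
    """Bytes moved when casting an (m x k) bf16 tensor to mxfp8 in a single pass."""
    n = m * k
    read_bytes = 2 * n                     # bf16 input, 2 bytes per element
    write_bytes = n + n // MX_BLOCK_SIZE   # fp8 data plus one e8m0 byte per block
    return read_bytes + write_bytes


def roofline_cast_seconds(m, k):
    """Lower bound on cast time if the kernel were purely bandwidth bound."""
    return mxfp8_cast_bytes(m, k) / B200_MEM_BW_BYTES_PER_S


def fraction_of_peak_bw(m, k, observed_seconds):
    """Observed bandwidth as a fraction of peak, given a measured kernel time."""
    return roofline_cast_seconds(m, k) / observed_seconds


print(f"{roofline_cast_seconds(16384, 16384) * 1e6:.1f} us lower bound for 16384 x 16384")
```

Comparing a measured kernel time against this bound gives the same kind of "percent of peak mem bw" number used in the summary, under the byte-count assumptions stated above.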
e2e inference performance
- need an inference roofline
- need to decide where to benchmark
For the to_mxfp8_dim1_kernel, the main performance blocker is shared memory bank conflicts, which arise from the transpose followed by the store:
col_normalized = tl.trans(col_normalized_t)
col_normalized = col_normalized.to(tl.float8e4nv)
tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)
The quantized mxfp8 data will be first stored to shared memory. Then it will be loaded from shared memory (using SASS LDS.U8), which is vectorized, and then stored back to global memory. During loading data from shared memory, a lot of bank conflicts happen as seen in the NCU profile.
The above was generated using:
export USE_IR_LOC=ttgir
ncu --set full --clock-control none --import-source on -f -o report python benchmarks/mx_formats/cast_bench.py --mode dim1_mx_triton --M 16384 --K 16384
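To make the bank-conflict mechanism concrete, here is a toy model, not derived from the actual layout Triton chooses for this kernel. It assumes 32 shared memory banks of 4 bytes each and a warp in which lane t reads one byte from row t of a tile, and counts how many distinct words the busiest bank has to serve. The function name and the strides are illustrative only.

```python
NUM_BANKS = 32    # shared memory banks (assumption: standard NVIDIA configuration)
BANK_WIDTH = 4    # bytes per bank word
WARP_SIZE = 32


def conflict_degree(row_stride_bytes, col=0):
    """Distinct words the busiest bank serves when lane t reads byte (t, col)."""
    words_per_bank = {}
    for lane in range(WARP_SIZE):
        word = (lane * row_stride_bytes + col) // BANK_WIDTH
        words_per_bank.setdefault(word % NUM_BANKS, set()).add(word)
    return max(len(words) for words in words_per_bank.values())


for stride in (32, 36, 128):
    print(f"row stride {stride} bytes: {conflict_degree(stride)}-way")
```

In this model a column read from an unpadded 32-byte-wide tile of 1-byte elements is 8-way conflicted, while a 36-byte row stride is conflict-free. Whether padding or a swizzled layout is the right fix for the real kernel would have to be checked against the NCU profile.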
FP8 vs MXFP8 Benchmark Comparison
References
- MXFP8: https://fburl.com/s2g726a1
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --model.converters mx --mx.recipe_name "mxfp8" --profiling.enable_profiling
step: 70 loss: 6.9682 memory: 35.94GiB(20.15%) tps: 12,798 tflops: 741.17 mfu: 16.47%
- FP8 Pertensor: https://fburl.com/99cbavtm
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --model.converters float8 --profiling.enable_profiling
step: 70 loss: 7.0653 memory: 35.92GiB(20.14%) tps: 13,674 tflops: 791.90 mfu: 17.60%
- BF16: https://fburl.com/29px2raz
CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --profiling.enable_profiling
step: 70 loss: 6.9436 memory: 36.00GiB(20.18%) tps: 11,388 tflops: 659.54 mfu: 14.66%
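The end-to-end speedups implied by these three log lines can be recomputed from the tps values. This is a small sketch using only the step-70 numbers above, so it reflects a single sample per run.

```python
# Tokens per second at step 70, copied from the three log lines above
tps = {"bf16": 11_388, "float8 tensorwise": 13_674, "mxfp8": 12_798}

speedup_vs_bf16 = {name: value / tps["bf16"] for name, value in tps.items()}

for name, speedup in speedup_vs_bf16.items():
    print(f"{name}: {speedup:.2f}x")
```

This gives roughly 1.20x for float8 tensorwise and 1.12x for mxfp8 relative to bf16 on this run.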
QKV Operations (including Output Projection)
| Operation | MXFP8 (μs) | FP8 (μs) | Performance |
|---|---|---|---|
| First FP8 GEMM | 152 | 144 | FP8 is 5% faster |
| Second FP8 GEMM | 52 | 61 | MXFP8 is 15% faster |
| Third FP8 GEMM | 51 | 53 | Noise (negligible difference) |
| Output Projection | 123 | 107 | FP8 is ~13% faster |
MLP Operations
| Test Run | MXFP8 (μs) | FP8 (μs) |
|---|---|---|
| mat 1 | 403 | 377 |
| mat 2 | 383 | 376 |
| mat 3 | 400 | 364 |
Average Performance:
- MXFP8: ~395 μs
- FP8: ~372 μs
- FP8 is ~6% faster on average
TL;DR: part of the forward-pass speed difference is attributable to gemm performance differences.
Macro Level
- Fwd for MXFP8: 2.2 ms
- Fwd for FP8 PerTensor: 1.33 ms
- FWD FP8 is 0.6x the latency of MXFP8
- Bwd for MXFP8: about 12.12 ms
- Bwd for FP8 PerTensor: 9.8 ms
- BWD FP8 is 0.8x the latency of MXFP8
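For reference, the MLP averages and the macro-level latency ratios can be recomputed directly from the figures reported above. The variable names are only for illustration.

```python
# Per-gemm MLP timings in microseconds, copied from the MLP table
mlp_us = {"mxfp8": [403, 383, 400], "fp8": [377, 376, 364]}
mlp_mean_us = {name: sum(times) / len(times) for name, times in mlp_us.items()}

# Forward and backward latencies in milliseconds, copied from the macro-level numbers
macro_ms = {
    "fwd": {"mxfp8": 2.2, "fp8": 1.33},
    "bwd": {"mxfp8": 12.12, "fp8": 9.8},
}
fp8_over_mxfp8 = {phase: t["fp8"] / t["mxfp8"] for phase, t in macro_ms.items()}

print({name: round(value, 1) for name, value in mlp_mean_us.items()})
print({phase: round(ratio, 2) for phase, ratio in fp8_over_mxfp8.items()})
```

The MLP gemms differ by about 6% while the full forward pass differs by about 40%, so most of the forward gap has to come from outside the gemms, which is consistent with the scaling/casting overhead called out in the summary.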