
MX single node performance tracker

Open vkuzo opened this issue 10 months ago • 2 comments

This issue tracks single-node performance of MX training and inference: fast gemms and fast fused kernels. Once this issue is complete, we will be able to train on a single node (8 GPUs) at SOTA performance with MXFP8, and run inference (details TBD) with MXFP8 and MXFP4.

training performance summary

As of 2025-03-27

  • e2e pretraining speedup vs bf16 + compile on LLaMa 3 8B, 8 B200 GPUs, torchtitan with default settings
    • 🟢 float8 tensorwise: 1.19x
    • 🔴 mxfp8: 1.09x (should be similar to tensorwise's 1.19x once we fix the issues below); right now, scaling/casting to mx is slow
  • 🟢 gemm speedup: cuBLAS mxfp8 gemm is 2x to 3x faster than bf16 - done for now
  • 🟢 mx casting to dim0 with torch.compile achieves up to 67% of peak mem bw - done for now
  • 🔴 mx casting to dim1 is our main performance gap
    • 🔴 torch.compile achieves only up to 9% of peak mem bw, tracking here: https://github.com/pytorch/pytorch/issues/149982
    • 🟡 handwritten triton kernel from https://github.com/pytorch/ao/pull/1932 achieves up to 44% of peak mem bw, can likely be improved further
  • 🔲 mx casting to dim0 + dim1 at the same time is postponed for now until we make the individual dim0 and dim1 kernels better

individual components

system overview (for training)

# There are three gemms in a forward + backward of a Linear layer:
#
# 1.       input @ weight_t    = output     (forward pass)
# 2. grad_output @ weight      = grad_input (backward pass)
# 3.     input_t @ grad_output = grad_weight (backward pass)
# 
# in Python pseudocode, we want the following (for mxfp8):

# forward pass

# inputs are in high precision
x_hp, w_hp = ...

# input @ weight_t = output
x_mx_dim0, x_scale_dim0 = to_mx(x_hp, dim=0)
w_mx_dim0, w_scale_dim0 = to_mx(w_hp, dim=0)
y = mx_gemm(x_mx_dim0, w_mx_dim0.t(), x_scale_dim0, w_scale_dim0)

# backward pass

# inputs are in high precision
x_hp, w_hp, go_hp = ...

# grad_output @ weight = grad_input
go_mx_dim0, go_scale_dim0 = to_mx(go_hp, dim=0)
w_mx_dim1, w_scale_dim1 = to_mx(w_hp.t().contiguous(), dim=0)
gi = mx_gemm(go_mx_dim0, w_mx_dim1.t(), go_scale_dim0, w_scale_dim1)

# input_t @ grad_output = grad_weight
go_mx_dim1, go_scale_dim1 = to_mx(go_hp.t().contiguous(), dim=0)
x_mx_dim1, x_scale_dim1 = to_mx(x_hp.t().contiguous(), dim=0)
gw = mx_gemm(go_mx_dim1, x_mx_dim1.t(), go_scale_dim1, x_scale_dim1)

We want:

  1. the mx gemm to be fast
  2. the cast from high precision to mx (to_mx in the pseudocode above) to be fast; a minimal sketch of what this cast computes is below
  3. the cast from high precision to mx to be fused to preceding/subsequent ops where possible
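
To make the to_mx calls above concrete, here is a minimal, unoptimized sketch of what an mxfp8 cast computes: one power-of-two (e8m0) scale per block of 32 elements, plus the fp8 (e4m3) payload. It only scales along the last dim, assumes the element count is divisible by 32, and uses int8 exponents as a stand-in for float8_e8m0fnu; it is illustrative only, not the torchao implementation.

```python
import torch

def to_mx_sketch(x_hp: torch.Tensor, block_size: int = 32):
    # Reshape so each row is one scaling block.
    x_blocks = x_hp.reshape(-1, block_size)
    # One power-of-two scale per block, derived from the block's max magnitude.
    amax = x_blocks.abs().amax(dim=1, keepdim=True).float().clamp(min=2.0**-127)
    scale_exp = torch.floor(torch.log2(amax)) - 8  # keep scaled values under the e4m3 max of 448
    scale = torch.exp2(scale_exp)
    # Scale, clamp to the e4m3 range, and cast the payload to fp8.
    x_fp8 = (x_blocks.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    # int8 exponents stand in for the float8_e8m0fnu scale dtype.
    return x_fp8.reshape(x_hp.shape), scale_exp.to(torch.int8)
```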

gemm kernel

Expected peak throughput on NVIDIA B200, without sparsity: 2.25 petaFLOPS for bf16, 4.5 petaFLOPS for fp8/fp6 (2x bf16), 9.0 petaFLOPS for fp4 (4x bf16) (source: https://resources.nvidia.com/en-us-blackwell-architecture, pages 19-20)

| kernel | wrapper | current TFLOPs | peak FLOPS | notes |
| --- | --- | --- | --- | --- |
| mxfp8 cuBLAS | torch._scaled_mm | TBD | 4.5 petaFLOPS | landed, https://github.com/pytorch/pytorch/pull/147548 |
| mxfp8 CUTLASS | torchao.ops.mx_fp8_bf16 | TBD | 4.5 petaFLOPS | landed, https://github.com/pytorch/ao/pull/1637 |
| mxfp4 CUTLASS | torchao.ops.mx_fp4_bf16 | TBD | 9.0 petaFLOPS | landed, https://github.com/pytorch/ao/pull/1661 |
| nvfp4 cuBLAS | torch._scaled_mm | TBD | 9.0 petaFLOPS | in progress, https://github.com/pytorch/pytorch/pull/148792 |

Once we have machines where benchmarking is possible, we should add easily reproducible gemm benchmarks and fill out the current TFLOPs column in the table above; a sketch of such a benchmark is below.
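
A rough sketch of how the achieved TFLOPs could be measured (a bf16 matmul is used as a placeholder; the mxfp8/mxfp4 wrappers from the table would be swapped in, and the shapes and iteration count here are arbitrary):

```python
import time
import torch

def bench_gemm_tflops(M: int, K: int, N: int, dtype=torch.bfloat16, iters: int = 20) -> float:
    # Placeholder bf16 gemm; swap in the mxfp8/mxfp4 gemm wrapper of interest.
    a = torch.randn(M, K, device="cuda", dtype=dtype)
    b = torch.randn(K, N, device="cuda", dtype=dtype)
    for _ in range(3):  # warmup
        a @ b
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    sec_per_iter = (time.perf_counter() - t0) / iters
    # 2*M*K*N FLOPs per gemm, reported in TFLOP/s.
    return 2 * M * K * N / sec_per_iter / 1e12
```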

scaling/casting kernels

Our current plan is to use torch.compile, same as we are doing with float8.

  • we should ensure we can generate a single fused kernel for scaling and casting a tensor to mxfp8. Today, torch.compile generates two kernels: https://github.com/pytorch/ao/issues/1769
  • once we have a single fused kernel, we should make sure it's bandwidth bound (a bandwidth-measurement sketch follows this list). As of 2025-02-24, the MX casting code is numerically correct but still research-quality and has not been optimized for performance. TODO issue.
  • the float8_e8m0fnu dtype was added to PyTorch in https://github.com/pytorch/pytorch/pull/147466, we need to update torchao to use this dtype for scales, and then ensure that PT2 works e2e. TODO issue
  • we need to ensure torch.compile generates good fused kernels for the custom scale packing layout required by B200s: https://github.com/pytorch/ao/issues/1773
  • we should ensure the cast across dim0 and dim1 is performant: https://github.com/pytorch/ao/issues/1788
  • given an MXLinear (fwd + bwd), we should expect at most six scale+cast kernels: two for each of input, weight, grad_output. The kernels for input and grad_output should be fused with preceding/subsequent ops as appropriate. TODO issue.
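
To quantify "bandwidth bound", here is a sketch of how achieved bandwidth could be measured and compared to the ~8 TB/s B200 peak. It calls the illustrative to_mx_sketch from above; a real measurement would call the torchao / torch.compile cast path instead:

```python
import time
import torch

def cast_gb_per_s(M: int = 16384, K: int = 16384, iters: int = 20):
    # Substitute the production cast kernel for to_mx_sketch to measure the real path.
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    for _ in range(3):  # warmup
        to_mx_sketch(x)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        to_mx_sketch(x)
    torch.cuda.synchronize()
    sec = (time.perf_counter() - t0) / iters
    # Ideal traffic: read bf16 (2 B/elem), write fp8 (1 B/elem) plus one e8m0 byte per 32 elems.
    bytes_moved = x.numel() * (2 + 1) + x.numel() // 32
    gb_per_s = bytes_moved / sec / 1e9
    return gb_per_s, 100 * gb_per_s / 8000  # achieved GB/s, and % of an 8 TB/s peak
```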

e2e training performance

From https://resources.nvidia.com/en-us-blackwell-architecture pages 19-20, on B200 the single GPU memory bandwidth we expect is 8 TB/s, the fp8/fp6 tensor core peak FLOPS is 4.5 petaFLOPS (without sparsity), and the fp4 tensor core peak FLOPS is 9.0 petaFLOPS (without sparsity).

  • we need a roofline of mx scaling/casting to identify the shapes expected to see speedups, and we should have a benchmark to compare observed vs theoretical performance (a toy roofline sketch follows this list)
  • [blocked] eventually we should get to SOTA performance in torchtitan. Currently, this work is blocked by general issues with Blackwell support in PyTorch, such as NCCL not working. Tracking is here: https://github.com/pytorch/pytorch/issues/145949
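
A toy roofline for a single Linear layer's fwd + bwd, using the B200 peaks quoted above: it compares ideal bf16 gemm time against ideal mxfp8 gemm time plus the ideal memory traffic of casting input, weight, and grad_output along both dims. It ignores fusion with surrounding ops, overlap, and cache effects, so it gives an upper bound on speedup, not a prediction:

```python
def mxfp8_linear_roofline_speedup(M: int, K: int, N: int,
                                  bf16_tflops: float = 2250.0,
                                  fp8_tflops: float = 4500.0,
                                  peak_gb_per_s: float = 8000.0) -> float:
    # Three gemms per fwd + bwd, 2*M*K*N FLOPs each.
    gemm_flops = 3 * 2 * M * K * N
    t_bf16_gemm = gemm_flops / (bf16_tflops * 1e12)
    t_mx_gemm = gemm_flops / (fp8_tflops * 1e12)
    # Cast traffic for input (M, K), weight (N, K), and grad_output (M, N), each along
    # two dims: read bf16 (2 B/elem), write fp8 (1 B/elem) plus ~1/32 B of e8m0 scale.
    cast_elems = 2 * (M * K + N * K + M * N)
    t_cast = cast_elems * (2 + 1 + 1 / 32) / (peak_gb_per_s * 1e9)
    return t_bf16_gemm / (t_mx_gemm + t_cast)

# example shape, chosen only to exercise the model (not a measured result)
print(mxfp8_linear_roofline_speedup(M=16384, K=4096, N=14336))
```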

e2e inference performance

  • need an inference roofline (a toy decode roofline sketch follows this list)
  • need to decide where to benchmark
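
As a starting point, a toy model for the memory-bound decode regime (assuming weight streaming dominates at small batch sizes; prefill is compute-bound and follows the gemm roofline above). The function and its numbers are illustrative assumptions, not measurements:

```python
def decode_weight_bw_speedup(elem_bits: int, bf16_bits: int = 16, block_size: int = 32) -> float:
    # Upper-bound decode speedup vs bf16 ~= reduction in bytes streamed per weight element,
    # counting one 8-bit e8m0 scale per block for the MX formats.
    mx_bits = elem_bits + 8 / block_size
    return bf16_bits / mx_bits

print(decode_weight_bw_speedup(8))  # mxfp8: ~1.94x upper bound
print(decode_weight_bw_speedup(4))  # mxfp4: ~3.76x upper bound
```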

vkuzo · Feb 24 '25 15:02

For the to_mxfp8_dim1_kernel, the main performance blocker is shared memory bank conflicts, which arise from the transpose followed by the store:

col_normalized = tl.trans(col_normalized_t)
col_normalized = col_normalized.to(tl.float8e4nv)
tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)

The quantized mxfp8 data is first stored to shared memory, then loaded back from shared memory (using SASS LDS.U8) and stored to global memory with vectorized accesses. The shared-memory loads incur many bank conflicts, as seen in the NCU profile below.

[NCU profile screenshot: shared memory bank conflicts on the LDS instructions]

The above was generated using:

export USE_IR_LOC=ttgir
ncu --set full --clock-control none --import-source on -f -o report python benchmarks/mx_formats/cast_bench.py --mode dim1_mx_triton --M 16384 --K 16384

syed-ahmed · May 08 '25 17:05

FP8 vs MXFP8 Benchmark Comparison

References

  • MXFP8: https://fburl.com/s2g726a1
    • CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --model.converters mx --mx.recipe_name "mxfp8" --profiling.enable_profiling
    • step: 70 loss: 6.9682 memory: 35.94GiB(20.15%) tps: 12,798 tflops: 741.17 mfu: 16.47%
  • FP8 per-tensor: https://fburl.com/99cbavtm
    • CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --model.converters float8 --profiling.enable_profiling
    • step: 70 loss: 7.0653 memory: 35.92GiB(20.14%) tps: 13,674 tflops: 791.90 mfu: 17.60%
  • BF16: https://fburl.com/29px2raz
    • CONFIG_FILE="torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh --model.print_after_conversion --training.compile --training.steps 150 --profiling.enable_profiling
    • step: 70 loss: 6.9436 memory: 36.00GiB(20.18%) tps: 11,388 tflops: 659.54 mfu: 14.66%

QKV Operations (including Output Projection)

| Operation | MXFP8 (μs) | FP8 (μs) | Performance |
| --- | --- | --- | --- |
| First FP8 GEMM | 152 | 144 | FP8 is ~5% faster |
| Second FP8 GEMM | 52 | 61 | MXFP8 is ~15% faster |
| Third FP8 GEMM | 51 | 53 | noise (negligible difference) |
| Output Projection | 123 | 107 | FP8 is ~13% faster |

MLP Operations

| Test Run | MXFP8 (μs) | FP8 (μs) |
| --- | --- | --- |
| mat 1 | 403 | 377 |
| mat 2 | 383 | 376 |
| mat 3 | 400 | 364 |

Average Performance:

  • FP8: ~395 μs
  • MXFP8: ~372 μs
  • MXFP8 is ~6% faster on average

TL;DR: part of the forward-pass latency difference is attributable to differences in gemm performance.

Macro Level

Fwd for MXFP8: 2.2 ms; Fwd for FP8 per-tensor: 1.33 ms

Fwd: FP8 is ~0.6x the latency of MXFP8

Bwd for MXFP8: ~12.12 ms; Bwd for FP8 per-tensor: 9.8 ms

Bwd: FP8 is ~0.8x the latency of MXFP8

drisspg · May 22 '25 21:05