
NVFP4 MoE Training Status


This issue keeps centralized tracking of NVFP4 training for the DeepSeek-V3 and LLAMA 4 models in torchtitan. We will keep this status updated.

Recipes

  • [ ] Current pretraining recipe (https://arxiv.org/pdf/2509.25149)

Kernels

  • [ ] NVFP4 GEMM

    • [x] Currently supported through cuBLAS in torch.nn.functional.scaled_mm
    • [ ] A Triton or CuTe DSL kernel will enable composability with features like Symmetric Memory. A CuTe DSL kernel is available in the current release.
  • [ ] NVFP4 Grouped GEMM

    • [x] ~Currently lacking integration into PyTorch.~ An earlier attempt was not satisfactory: https://github.com/pytorch/pytorch/pull/156806
      • [x] NVFP4 grouped GEMM (via torch.nn.functional.scaled_grouped_mm) - https://github.com/pytorch/pytorch/pull/166308
    • [ ] A CuTe DSL kernel is available in the current release. A Triton or CuTe DSL kernel will enable composability with features like Symmetric Memory.
    • [x] BF16 Grouped GEMM CuTe DSL integration into inductor: https://github.com/pytorch/pytorch/pull/165036/, https://github.com/pytorch/pytorch/issues/165785
  • [ ] NVFP4 GEMM/Grouped GEMM variants for Blackwell Ultra

    • [ ] CuTe DSL kernel will be available in a later release.
    • [ ] cuBLAS support is planned for a future CUDA release.
  • [ ] Random Hadamard Transform

    • [ ] Currently lacking a native implementation. An implementation is available in TE (Transformer Engine), but we need to consider composability and maintenance.
  • [ ] Quantize with Stochastic Rounding

    • [ ] Currently lacking a native implementation. An implementation is available in TE (Transformer Engine), but we need to consider composability and maintenance (a reference sketch covering this item and the Hadamard transform follows this list).
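
For reference, the sketch below shows what these last two items could look like at the tensor level in plain PyTorch. It is not the TE implementation; the 16-element block size, the E2M1 magnitude grid, and all helper names are assumptions for illustration only. A real kernel would fuse these steps and emit packed FP4 data plus E4M3 block scales rather than the dequantized values returned here.

```python
# Reference-only sketch (not the TE kernels): a randomized Hadamard transform
# followed by stochastic rounding onto the FP4 (E2M1) value grid.
# Assumptions for illustration: 16-element blocks, sign handled separately,
# no packing and no scale-factor emission.
import torch

# Magnitudes representable by E2M1 (4-bit: 1 sign, 2 exponent, 1 mantissa bit).
FP4_E2M1_MAGNITUDES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def hadamard_matrix(n: int, device=None, dtype=torch.float32) -> torch.Tensor:
    """Orthonormal Hadamard matrix via the Sylvester construction (n = power of 2)."""
    H = torch.ones(1, 1, device=device, dtype=dtype)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H / (n ** 0.5)


def random_hadamard_transform(x: torch.Tensor, block: int = 16, seed: int = 0):
    """Apply H @ diag(s), with a shared random ±1 sign vector s, to each
    `block`-sized chunk of x (x.numel() must be divisible by `block`)."""
    g = torch.Generator().manual_seed(seed)
    s = (torch.randint(0, 2, (block,), generator=g) * 2 - 1).to(x.device, x.dtype)
    H = hadamard_matrix(block, device=x.device, dtype=x.dtype)
    xb = x.reshape(-1, block) * s        # random sign flips decorrelate outliers
    return (xb @ H.T).reshape_as(x), s   # rotate each block


def stochastic_round_to_fp4(x: torch.Tensor) -> torch.Tensor:
    """Round |x| to one of the two nearest E2M1 grid points, choosing the upper
    point with probability proportional to proximity. Returns dequantized values."""
    grid = FP4_E2M1_MAGNITUDES.to(x.device, x.dtype)
    mag = x.abs().clamp(max=grid[-1])
    hi_idx = torch.bucketize(mag, grid)              # first grid point >= mag
    lo = grid[(hi_idx - 1).clamp(min=0)]
    hi = grid[hi_idx.clamp(max=grid.numel() - 1)]
    span = (hi - lo).clamp(min=torch.finfo(x.dtype).eps)
    round_up = torch.rand_like(mag) < (mag - lo) / span
    return torch.where(round_up, hi, lo) * x.sign()
```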

Execution Plan

  • [ ] Need RFC for CuTe DSL NVFP4 GEMM/Grouped GEMM in torch.nn.functional.scaled_mm in PyTorch: https://github.com/pytorch/pytorch/issues/166611
  • [ ] TorchAO Execution Plan: https://github.com/pytorch/ao/issues/3293

Test Plan

  • [ ] Functionality
    • [ ] E2E Convergence runs
      • [ ] LLAMA 4
      • [ ] DeepSeek-V3
  • [ ] Performance
    • [ ] Microbenchmarks for cuBLAS vs. CuTe DSL GEMM/Grouped GEMM kernels (a harness sketch follows this list)
    • [ ] E2E Performance benchmarks
      • [ ] LLAMA 4
      • [ ] DeepSeek-V3
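
As a starting point for the microbenchmark item above, here is a hedged sketch of a harness built on torch.utils.benchmark. Only a BF16 baseline is wired up; the shapes and the commented-out NVFP4 entries are placeholders, since the actual cuBLAS and CuTe DSL entry points are still being settled in the linked issues.

```python
# Hypothetical microbenchmark harness for "cuBLAS vs CuTe DSL GEMM/Grouped GEMM".
# Only a BF16 baseline is wired up; the NVFP4 entries are placeholders to be
# pointed at whichever kernels land (scaled_mm, scaled_grouped_mm, CuTe DSL wrappers).
import torch
import torch.utils.benchmark as benchmark


def bench(label: str, description: str, fn, *args) -> benchmark.Measurement:
    # Timer handles warmup and CUDA synchronization around the timed region.
    timer = benchmark.Timer(
        stmt="fn(*args)",
        globals={"fn": fn, "args": args},
        label=label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=1.0)


def bf16_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in baseline; replace with the NVFP4 paths under test.
    return a @ b


if __name__ == "__main__":
    m, n, k = 8192, 8192, 8192
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)

    results = [bench(f"gemm {m}x{n}x{k}", "bf16 baseline", bf16_matmul, a, b)]
    # results.append(bench(f"gemm {m}x{n}x{k}", "cuBLAS NVFP4", ...))    # placeholder
    # results.append(bench(f"gemm {m}x{n}x{k}", "CuTe DSL NVFP4", ...))  # placeholder
    benchmark.Compare(results).print()
```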

syed-ahmed avatar Oct 29 '25 05:10 syed-ahmed

CC: @slayton58 @ngimel @supriyar @Priyadlfw @ptrblck @eqy

Please feel free to add anything missing or suggest updates.

syed-ahmed avatar Oct 29 '25 05:10 syed-ahmed

@syed-ahmed

  • NVFP4 grouped gemm (via torch.nn.functional.scaled_grouped_mm) was merged yesterday - https://github.com/pytorch/pytorch/pull/166308
  • NVFP4 gemm support is via torch.nn.functional.scaled_mm, not torch._scaled_mm

"NVFP4 GEMM/Grouped GEMM variants for blackwell ultra ... Supported by both cuBLAS and CuTe DSL"

Where is the grouped NVFP4 support in cuBLAS? I'm not finding it through cursory looks through the docs / examples / Google.

slayton58 avatar Oct 29 '25 12:10 slayton58

Thanks @slayton58. My bad, cuBLAS plans to support grouped NVFP4 GEMM in a future CUDA release. I've updated the text.

syed-ahmed avatar Oct 29 '25 15:10 syed-ahmed

cc @danielvegamyhre @vkuzo

supriyar avatar Oct 29 '25 17:10 supriyar

Also very curious where the CuTe DSL NVFP4 kernels are available; the ones in https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_blockscaled_gemm.py#L2857 synchronize like there's no tomorrow (at the linked line). Also not sure if they support the dynamic-K case needed for the weight gradient.

ngimel avatar Oct 29 '25 23:10 ngimel

Added this issue for the NVFP4 CuTe DSL-specific discussion: https://github.com/pytorch/pytorch/issues/166611. The NVFP4 GEMM is linked there, and that is indeed the grouped scaled mm example. Let's follow up on that thread about what's missing in these examples and how we can improve them.

syed-ahmed avatar Oct 30 '25 04:10 syed-ahmed

Thanks for adding this @syed-ahmed. In torchao we plan to work on NVFP4 training for MoEs (i.e., dynamic quantization + NVFP4 grouped GEMMs for routed experts) and integrate it into torchtitan in the coming months. This should be able to re-use some existing infra from mxfp8 MoE training (tensor subclassing, composability with parallelisms, etc.), but we'll need to add some new quantization kernels, as well as the Hadamard transform, stochastic rounding, etc.

I'll scope things out in more detail and create an issue to organize the specific work that needs to be done in torchao, and link it here for visibility/coordination. Just FYI we are focused on finishing up some final improvements for mxfp8 moe training and validating at scale, but nvfp4 is the next top priority after that.
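
For illustration only, a rough sketch of the overall shape of that routed-expert path (sort tokens by expert, dynamically quantize, run one grouped GEMM, un-permute). Every helper here is a hypothetical placeholder standing in for the torchao/PyTorch kernels discussed in this thread, not an existing API.

```python
# Illustration only: rough structure of dynamic NVFP4 quantization in front of a
# grouped GEMM for routed experts. The two quantization/GEMM helpers are
# hypothetical placeholders, not existing torchao or PyTorch APIs.
import torch


def nvfp4_block_quantize(x: torch.Tensor, block: int = 16):
    """Hypothetical: would return packed FP4 data plus E4M3 block scales
    (this is where the Hadamard transform / stochastic rounding would slot in)."""
    raise NotImplementedError("placeholder for a real quantization kernel")


def nvfp4_grouped_mm(x_q, x_scales, w_q, w_scales, offsets):
    """Hypothetical: would wrap an NVFP4 grouped GEMM (e.g. a scaled_grouped_mm path)."""
    raise NotImplementedError("placeholder for a real grouped GEMM kernel")


def routed_expert_forward(tokens, expert_ids, w_q, w_scales, num_experts):
    # 1) Permute tokens so each expert's tokens are contiguous, and compute the
    #    per-group offsets the grouped GEMM needs.
    order = torch.argsort(expert_ids)
    sorted_tokens = tokens[order]
    counts = torch.bincount(expert_ids, minlength=num_experts)
    offsets = torch.cumsum(counts, dim=0)

    # 2) Dynamically quantize the permuted activations.
    x_q, x_scales = nvfp4_block_quantize(sorted_tokens)

    # 3) One grouped GEMM over all experts, then un-permute back to token order.
    out_sorted = nvfp4_grouped_mm(x_q, x_scales, w_q, w_scales, offsets)
    out = torch.empty_like(out_sorted)
    out[order] = out_sorted
    return out
```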

danielvegamyhre avatar Oct 30 '25 22:10 danielvegamyhre