
[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked

Open · wenscarl opened this pull request 2 months ago • 9 comments

Purpose

Add grouped_gemm_nt_masked from flashinfer to support nvfp4 MoE.

Depends on the silu_and_mul nvfp4 quantization fusion rework.
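
For context, a minimal sketch of the intended call pattern, assuming the tensor layouts and scale computation used in tests/kernels/moe/test_cutedsl_moe.py (quoted later in this thread); the wrapper's import path is omitted here and the shapes are illustrative, not the final API:

```python
import torch

# Constants as used in the test below (float8_e4m3fn max and fp4 e2m1 max).
FLOAT8_E4M3_MAX = 448.0
FLOAT4_E2M1_MAX = 6.0

def masked_moe_gemm(hidden_states_3d, weights, masked_m):
    # hidden_states_3d: [num_experts, max_tokens_per_expert, hidden_dim], bf16;
    #                   rows beyond masked_m[e] are padding for expert e.
    # weights:          [num_experts, inter_dim, hidden_dim], bf16.
    # masked_m:         [num_experts], number of valid tokens per expert.
    a_amax = hidden_states_3d.abs().amax(dim=(1, 2)).to(torch.float32)
    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32)
    # Per-expert global scales for nvfp4 quantization (see the NaN note at the
    # end of this thread about initializing these defensively).
    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
    # vLLM wrapper around flashinfer's grouped_gemm_nt_masked; only rows
    # [:masked_m[e]] of each expert's output are meaningful.
    return flashinfer_cutedsl_grouped_gemm_nt_masked(
        hidden_states_3d, a_gs, weights, b_gs, masked_m
    )
```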

Test Plan

VLLM_WORKER_MULTIPROC_METHOD="spawn" \
VLLM_ALL2ALL_BACKEND="masked_gemm" \
VLLM_USE_STANDALONE_COMPILE=0 \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="cutedsl" \
lm_eval --model vllm --model_args pretrained=/dev/shm/checkpoints/nvidia-DeepSeek-R1-0528-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=False,tensor_parallel_size=1,max_model_len=2048 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

Test Result

vllm (pretrained=/dev/shm/checkpoints/nvidia-DeepSeek-R1-0528-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=True,tensor_parallel_size=1,max_model_len=2048,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9591|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9538|±  |0.0058|

Essential Elements of an Effective PR Description Checklist
  • [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • [ ] The test plan, such as providing test command.
  • [ ] The test results, such as pasting the results comparison before and after, or e2e results
  • [ ] (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

wenscarl · Sep 30 '25 21:09

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Oct 06 '25 21:10

Thanks for working on this! I think this will also help enable gpt-oss + DeepEPLowLatency on Blackwell 🙌

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Oct 14 '25 03:10

@mgoin https://github.com/flashinfer-ai/flashinfer/pull/1927 is merged. Should unblock this PR.

wenscarl · Oct 23 '25 12:10

Okay, we still need to wait for the next flashinfer release, right? I still see 0.4.1 as the latest.

mgoin · Oct 23 '25 18:10

> Okay, we still need to wait for the next flashinfer release, right? I still see 0.4.1 as the latest.

Ping. A new version of flashinfer has been released.

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

mergify[bot] · Nov 11 '25 16:11

@wenscarl When I run the test locally, I see a failure for the last case. PTAL:

tests/kernels/moe/test_cutedsl_moe.py .......F                                                                                                                                                         [100%]

================================================================================================== FAILURES ==================================================================================================
_________________________________________________________________________________ test_grouped_gemm_nt_masked[16-128-512-5] __________________________________________________________________________________

bs = 16, hidden_dim = 128, inter_dim = 512, topk = 5

    @pytest.mark.parametrize(
        "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
    )
    @torch.inference_mode()
    def test_grouped_gemm_nt_masked(
        bs: int, hidden_dim: int, inter_dim: int, topk: int
    ) -> None:
        torch.manual_seed(42)
        B = bs
        D = hidden_dim
        N = inter_dim
        # CuteDSL group gemm has issue when not all experts are active.
        # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
        # see https://github.com/flashinfer-ai/flashinfer/issues/1856
        num_experts = bs
        hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
        weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
        router_logits = torch.randn(B, num_experts, dtype=torch.float32)
    
        hidden_states_expanded = (
            hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        )
        hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
            hidden_states_expanded, router_logits, num_experts, topk
        )
    
        a_amax = (
            hidden_states_3d.abs()
            .amax(dim=(1, 2))
            .to(torch.float32)
            .to(hidden_states.device)
        )
        b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
        a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
        b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
        out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
            hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
        )
        # reference
        out_ref = grouped_gemm_ref(
            hidden_states_expanded=hidden_states_expanded,
            hidden_states_3d=hidden_states_3d,
            weights=weights,
            topk_idx=topk_idx,
            masked_m=masked_m,
            B=B,
            topk=topk,
            num_experts=num_experts,
        )
        # Note: just to compare the masked position due to cutedsl may write nan
        # into unmasked position.
        for i in range(num_experts):
>           torch.testing.assert_close(
                out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
                out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
                atol=1e-1,
                rtol=1e-1,
            )
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 1529 / 1536 (99.5%)
E           Greatest absolute difference: 42.5 at index (1, 212) (up to 0.1 allowed)
E           Greatest relative difference: 1.0 at index (0, 0) (up to 0.1 allowed)

tests/kernels/moe/test_cutedsl_moe.py:570: AssertionError

It's because the global scaling factors contain NaNs. Fixed by filling them with 1s at initialization.
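
For illustration, a minimal sketch of that kind of initialization (hypothetical helper, not the actual diff): start the per-expert global scales at 1 so experts with masked_m == 0 never contribute NaN/inf values.

```python
import torch

FLOAT8_E4M3_MAX = 448.0  # float8_e4m3fn max
FLOAT4_E2M1_MAX = 6.0    # fp4 e2m1 max

def safe_global_scales(x_3d: torch.Tensor, masked_m: torch.Tensor) -> torch.Tensor:
    """Per-expert global scales, initialized to 1.0 so inactive experts
    (masked_m == 0) never yield NaN/inf from a zero or garbage amax."""
    num_experts = x_3d.shape[0]
    gs = torch.ones(num_experts, dtype=torch.float32, device=x_3d.device)
    active = masked_m.to(x_3d.device) > 0
    amax = x_3d.abs().amax(dim=(1, 2)).to(torch.float32)
    gs[active] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax[active]
    return gs
```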

wenscarl · Nov 18 '25 20:11