[MoE] Nvfp4 Masked Gemm: Add flashinfer grouped_gemm_nt_masked
Purpose
Add grouped_gemm_nt_masked from flashinfer to support nvfp4 MoE.
Depends on the silu_and_mul nvfp4 quantization fusion rework.
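For context, a masked grouped GEMM performs one GEMM per expert, but only over the first masked_m[e] rows of that expert's activation block; inactive experts (masked_m[e] == 0) contribute nothing. Below is a minimal PyTorch sketch of that semantics, not the flashinfer kernel itself: tensor names mirror the test further down, and the real kernel's nvfp4 block scaling, output layout, and dtype handling are deliberately simplified.

```python
import torch


def masked_grouped_gemm_ref(
    a: torch.Tensor,         # [num_experts, max_tokens, K] per-expert activations
    b: torch.Tensor,         # [num_experts, N, K] per-expert weights ("NT": B is used transposed)
    masked_m: torch.Tensor,  # [num_experts] number of valid rows per expert
) -> torch.Tensor:
    """Reference semantics: only the first masked_m[e] rows of each expert are computed."""
    num_experts, max_tokens, _ = a.shape
    n = b.shape[1]
    out = torch.zeros(num_experts, max_tokens, n, dtype=a.dtype, device=a.device)
    for e in range(num_experts):
        m = int(masked_m[e])
        if m == 0:
            continue  # inactive expert: its rows are neither read nor written
        out[e, :m] = a[e, :m] @ b[e].t()  # [m, K] x [K, N] -> [m, N]
    return out
```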
Test Plan
VLLM_WORKER_MULTIPROC_METHOD="spawn" \
VLLM_ALL2ALL_BACKEND="masked_gemm" \
VLLM_USE_STANDALONE_COMPILE=0 \
VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND="cutedsl" \
lm_eval --model vllm --model_args pretrained=/dev/shm/checkpoints/nvidia-DeepSeek-R1-0528-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=False,tensor_parallel_size=1,max_model_len=2048 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
Test Result
vllm (pretrained=/dev/shm/checkpoints/nvidia-DeepSeek-R1-0528-FP4,quantization=modelopt_fp4,data_parallel_size=8,enable_expert_parallel=True,tensor_parallel_size=1,max_model_len=2048,enforce_eager=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9591|± |0.0055|
| | |strict-match | 5|exact_match|↑ |0.9538|± |0.0058|
Essential Elements of an Effective PR Description Checklist
- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
- [ ] The test plan, such as providing test command.
- [ ] The test results, such as pasting the results comparison before and after, or e2e results
- [ ] (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
- [ ] (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @wenscarl.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
Thanks for working on this! I think this will also help enable gpt-oss + DeepEPLowLatency on Blackwell 🙌
@mgoin https://github.com/flashinfer-ai/flashinfer/pull/1927 is merged. Should unblock this PR.
Okay, we still need to wait for the next flashinfer release, right? I still see 0.4.1 as the latest.
Ping. A new version of flashinfer has been released.
@wenscarl When I run the test locally, I see a failure for the last case, PTAL
tests/kernels/moe/test_cutedsl_moe.py .......F    [100%]

=================================== FAILURES ===================================
_________________ test_grouped_gemm_nt_masked[16-128-512-5] _________________

bs = 16, hidden_dim = 128, inter_dim = 512, topk = 5

    @pytest.mark.parametrize(
        "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
    )
    @torch.inference_mode()
    def test_grouped_gemm_nt_masked(
        bs: int, hidden_dim: int, inter_dim: int, topk: int
    ) -> None:
        torch.manual_seed(42)
        B = bs
        D = hidden_dim
        N = inter_dim
        # CuteDSL group gemm has issue when not all experts are active.
        # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive
        # see https://github.com/flashinfer-ai/flashinfer/issues/1856
        num_experts = bs
        hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
        weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
        router_logits = torch.randn(B, num_experts, dtype=torch.float32)
        hidden_states_expanded = (
            hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
        )
        hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
            hidden_states_expanded, router_logits, num_experts, topk
        )
        a_amax = (
            hidden_states_3d.abs()
            .amax(dim=(1, 2))
            .to(torch.float32)
            .to(hidden_states.device)
        )
        b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
        a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
        b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
        out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
            hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
        )
        # reference
        out_ref = grouped_gemm_ref(
            hidden_states_expanded=hidden_states_expanded,
            hidden_states_3d=hidden_states_3d,
            weights=weights,
            topk_idx=topk_idx,
            masked_m=masked_m,
            B=B,
            topk=topk,
            num_experts=num_experts,
        )
        # Note: just to compare the masked position due to cutedsl may write nan
        # into unmasked position.
        for i in range(num_experts):
>           torch.testing.assert_close(
                out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
                out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
                atol=1e-1,
                rtol=1e-1,
            )
E           AssertionError: Tensor-likes are not close!
E
E           Mismatched elements: 1529 / 1536 (99.5%)
E           Greatest absolute difference: 42.5 at index (1, 212) (up to 0.1 allowed)
E           Greatest relative difference: 1.0 at index (0, 0) (up to 0.1 allowed)

tests/kernels/moe/test_cutedsl_moe.py:570: AssertionError
It's because the global scaling factors contain NaNs. Fixed by filling them with 1s at initialization.
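For illustration, that class of fix amounts to initializing the global scale buffers with ones instead of leaving them uninitialized, so experts that receive no tokens do not inject NaN or garbage into the masked GEMM. A hedged sketch follows; the tensor name and sizes are illustrative, not the exact buffers touched in this PR.

```python
import torch

num_experts = 8  # illustrative

# Problematic pattern: torch.empty leaves whatever bytes were already in memory,
# which can decode to NaN for experts that are never written.
# a_global_scale = torch.empty(num_experts, dtype=torch.float32, device="cuda")

# Fix as described above: fill with 1.0 at initialization so inactive experts
# carry a benign identity scale.
a_global_scale = torch.ones(num_experts, dtype=torch.float32, device="cuda")
```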