
[moe] fix: correct the cache size in the last chunk


Motivation

The following code triggers an AssertionError:

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

N = 64 * 1024 + 10  # 10 tokens past the 64 * 1024 chunk boundary, so the final chunk is short
E = 8               # number of experts
H = 1024            # hidden size
I = 4096            # intermediate size

x = torch.randn((N, H), device="cuda", dtype=torch.float16)
w1 = torch.randn((E, I * 2, H), device="cuda", dtype=torch.float16)  # fused gate/up projection
w2 = torch.randn((E, H, I), device="cuda", dtype=torch.float16)      # down projection

gating_output = torch.randn((N, E), device="cuda", dtype=torch.float16)
topk = 2  # experts selected per token

fused_moe(x, w1, w2, gating_output, topk, True)  # True -> renormalize top-k weights

In the original implementation, intermediate_cache2 holds the flattened activations computed across all experts for all tokens, so its first dimension is defined as topk_ids.shape[1] times that of intermediate_cache1. However, when the caches are sliced down for the final chunk, intermediate_cache2 was sliced by the token count alone, dropping the topk_ids.shape[1] factor and misaligning the activation output (hence the AssertionError above).
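For intuition, here is a minimal standalone sketch (illustrative values only, not the library code) of how the two caches relate and why slicing by the token count alone is wrong:

import torch

M, topk, I = 10, 2, 4096  # e.g. the 10 tokens left over in the last chunk

# cache1 holds the first grouped GEMM's output, one row per (token, expert) pair.
cache1 = torch.empty((M, topk, 2 * I))
# cache2 holds the activation output with the top-k dimension flattened,
# so its first dimension is topk times that of cache1.
cache2 = torch.empty((M * topk, I))
assert cache2.shape[0] == topk * cache1.shape[0]

# Slicing the last chunk as cache2[:M] (instead of cache2[:M * topk])
# keeps only 1/topk of the rows the chunk actually needs.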

Modifications

This PR fixes the issue by multiplying the slice length of intermediate_cache2 in the last chunk by topk_ids.shape[1].
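In rough pseudocode, the slicing in the last-chunk branch changes as follows (a paraphrased sketch, not the literal diff; tokens_in_chunk names the number of tokens in the current chunk):

# Before: sliced by token count alone, like the other caches.
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]

# After: account for the flattened top-k dimension.
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]]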

This is my first PR. Reviewers are welcome to help fold the above script into the unit tests; a possible pytest wrapper is sketched below.
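A pytest-style wrapper for the script might look like this (the test name, file placement, and the CUDA skip guard are my assumptions):

import pytest
import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_fused_moe_last_chunk():
    # 10 tokens past the 64 * 1024 boundary exercise the short last chunk.
    N, E, H, I = 64 * 1024 + 10, 8, 1024, 4096
    x = torch.randn((N, H), device="cuda", dtype=torch.float16)
    w1 = torch.randn((E, I * 2, H), device="cuda", dtype=torch.float16)
    w2 = torch.randn((E, H, I), device="cuda", dtype=torch.float16)
    gating_output = torch.randn((N, E), device="cuda", dtype=torch.float16)
    # Same call as the repro above; should run without an AssertionError.
    fused_moe(x, w1, w2, gating_output, 2, True)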

Checklist

  • [ ] Format your code according to the Code Formatting with Pre-Commit.
  • [ ] Add unit tests as outlined in the Running Unit Tests.
  • [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
  • [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
  • [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
  • [ ] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.
