[moe] fix: correct the cache size in the last chunk
## Motivation
The following code triggers an AssertionError:
```python
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

N = 64 * 1024 + 10
E = 8
H = 1024
I = 4096

x = torch.randn((N, H), device="cuda", dtype=torch.float16)
w1 = torch.randn((E, I * 2, H), device="cuda", dtype=torch.float16)
w2 = torch.randn((E, H, I), device="cuda", dtype=torch.float16)
gating_output = torch.randn((N, E), device="cuda", dtype=torch.float16)
topk = 2
fused_moe(x, w1, w2, gating_output, topk, True)
```
In the original implementation, intermediate_cache2 holds the flattened activations computed across all selected experts for all tokens: its first dimension is defined as topk_ids.shape[1] times that of intermediate_cache1. However, when the caches are resized for the final chunk, this factor was dropped, so intermediate_cache2 ends up too small and the assertion above fails.
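For illustration, here is a minimal, hypothetical sketch of the shape relationship (placeholder sizes; the actual allocation in fused_moe.py may differ in details):

```python
import torch

# Illustrative sketch only, not the exact allocation code in fused_moe.py.
# intermediate_cache2 is flattened over (token, expert) pairs, so its first
# dimension carries an extra topk factor. Tiny sizes keep this runnable on CPU.
M, topk, N, H = 4, 2, 8, 16  # tokens per chunk, experts per token, 2 * I, hidden size

intermediate_cache1 = torch.empty((M, topk, N), dtype=torch.float16)
intermediate_cache2 = torch.empty((M * topk, N // 2), dtype=torch.float16)
intermediate_cache3 = torch.empty((M, topk, H), dtype=torch.float16)

assert intermediate_cache2.shape[0] == topk * intermediate_cache1.shape[0]
```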
## Modifications
This PR fixes the issue by multiplying the first-dimension size of intermediate_cache2 for the last chunk by topk_ids.shape[1].
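As a rough sketch of what goes wrong and what the corrected sizing looks like (scaled-down placeholder sizes, not the literal diff; `topk` stands in for topk_ids.shape[1]):

```python
import torch

# Hypothetical, scaled-down version of the last-chunk slicing: CHUNK_SIZE is
# shrunk so the snippet runs on CPU, but the arithmetic mirrors the bug.
CHUNK_SIZE, topk, N = 4, 2, 8
num_tokens = CHUNK_SIZE + 2                    # mirrors N = 64 * 1024 + 10 above
intermediate_cache2 = torch.empty((CHUNK_SIZE * topk, N // 2), dtype=torch.float16)

tokens_in_chunk = num_tokens - CHUNK_SIZE      # size of the final, partial chunk

# Before: sliced like intermediate_cache1, dropping the topk factor -> too small.
too_small = intermediate_cache2[:tokens_in_chunk]           # shape (2, 4)
# After: the first dimension keeps the topk factor, matching the fix in this PR.
fixed = intermediate_cache2[:tokens_in_chunk * topk]        # shape (4, 4)

assert fixed.shape[0] == tokens_in_chunk * topk
```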
This is my first PR; reviewers are welcome to help add the reproduction script above to the unit tests.
## Checklist
- [ ] Format your code according to the Code Formatting with Pre-Commit.
- [ ] Add unit tests as outlined in the Running Unit Tests.
- [ ] Update documentation / docstrings / example tutorials as needed, according to Writing Documentation.
- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to Benchmark and Profiling and Accuracy Results.
- [ ] For reviewers: If you haven't made any contributions to this PR and are only assisting with merging the main branch, please remove yourself as a co-author when merging the PR.
- [ ] Please feel free to join our Slack channel at https://slack.sglang.ai to discuss your PR.