feat: Optionally split MoE inputs into chunks to reduce GPU memory usage
If max_num_tokens is large and attention DP is enabled on a relatively large number of GPUs, the MoE workspace becomes very large, which can cause OOM errors.
This MR optionally splits MoE inputs into chunks to reduce GPU memory usage. To enable this feature, set moe_max_num_tokens in pytorch_backend_config. With it set, at most moe_max_num_tokens tokens are sent to torch.ops.trtllm.fused_moe at a time; if the number of tokens exceeds moe_max_num_tokens, the input tensors are split into chunks and processed in a loop.
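A minimal sketch of the chunked dispatch (the two-tensor signature below is a simplification of the real torch.ops.trtllm.fused_moe argument list, and forward_chunked is an illustrative name rather than the actual code in this MR):

```python
from typing import Callable

import torch


def forward_chunked(
    fused_moe: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    moe_max_num_tokens: int,
) -> torch.Tensor:
    """Send at most moe_max_num_tokens tokens to the fused MoE op at a time."""
    if hidden_states.shape[0] <= moe_max_num_tokens:
        return fused_moe(hidden_states, router_logits)

    # Split along the token dimension and process the chunks in a loop.
    outputs = [
        fused_moe(h, r)
        for h, r in zip(hidden_states.split(moe_max_num_tokens),
                        router_logits.split(moe_max_num_tokens))
    ]
    return torch.cat(outputs, dim=0)
```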
To achieve better performance, an extra CUDA stream is used so that the computation and communication of adjacent chunks can overlap.
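The overlap scheme can be sketched with two streams (compute_fn and communicate_fn are placeholders for the fused MoE kernels and the attention-DP communication; stream-aware memory lifetime handling is omitted for brevity):

```python
import torch


def run_chunks_overlapped(chunks, compute_fn, communicate_fn):
    """Alternate chunks between the current stream and an extra CUDA stream so
    that communication for one chunk can overlap with computation for the next.
    """
    main_stream = torch.cuda.current_stream()
    side_stream = torch.cuda.Stream()
    # The side stream must not start before the chunk inputs are ready.
    side_stream.wait_stream(main_stream)

    outputs = []
    for i, chunk in enumerate(chunks):
        stream = main_stream if i % 2 == 0 else side_stream
        with torch.cuda.stream(stream):
            outputs.append(communicate_fn(compute_fn(chunk)))

    # The caller's stream must wait for the side stream before using the outputs.
    main_stream.wait_stream(side_stream)
    return outputs
```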
/bot run
PR_Github #585 [ run ] triggered by Bot
PR_Github #585 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #498 completed with status: 'FAILURE'
/bot run
PR_Github #599 [ run ] triggered by Bot
PR_Github #599 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #508 completed with status: 'SUCCESS'
@jinyangyuan-nvidia Thanks for adding this feature, Jinyang.
I noticed that this MR currently only applies the "optionally split MoE inputs into chunks" feature to the DS R1 model. How much additional effort would be needed to make this feature applicable to other MoE models as well?
cc @hlu1 @QiJune for vis.
Thanks June. This feature can be easily applied to other MoE models by refactoring the code. I will improve this PR accordingly.
Thanks, Jinyang!
June
/bot run --add-multi-gpu-test
PR_Github #694 [ run ] triggered by Bot
PR_Github #694 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #581 completed with status: 'FAILURE'
/bot run --add-multi-gpu-test
PR_Github #697 [ run ] triggered by Bot
PR_Github #697 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #584 completed with status: 'FAILURE'
/bot run
PR_Github #699 [ run ] triggered by Bot
PR_Github #699 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #586 completed with status: 'FAILURE'
This LGTM, thanks @jinyangyuan-nvidia.
One bigger change to consider with this sort of approach: we could couple the chunk size with the DP number of tokens per rank and treat each chunk as a distinct microbatch. That way each microbatch is fully decoupled, so we don't need to synchronize at the end of the MoE layer and can immediately start processing on the relevant DP rank. Ideally, this would naturally fill the pipeline so that the only exposed latency is in the last layer (rough sketch below).
My only concern is that we could end up overlapping attention and MoE, slowing both down or introducing new bubbles.
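Roughly, the per-microbatch synchronization would look something like this (moe_fn and downstream_fn are hypothetical placeholders, only meant to illustrate waiting on one microbatch's event instead of a layer-wide barrier):

```python
import torch


def moe_layer_pipelined(microbatches, moe_fn, downstream_fn):
    """Record an event per microbatch so the downstream work for microbatch i
    waits only for microbatch i, not for the whole MoE layer.
    """
    moe_stream = torch.cuda.Stream()
    moe_stream.wait_stream(torch.cuda.current_stream())

    results = []
    for mb in microbatches:
        with torch.cuda.stream(moe_stream):
            out = moe_fn(mb)
            ready = torch.cuda.Event()
            ready.record(moe_stream)
        # Only this microbatch's consumer waits on this microbatch's event.
        torch.cuda.current_stream().wait_event(ready)
        results.append(downstream_fn(out))
    return results
```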
Thanks @djns99. This is a really great idea and worth looking into. Next week I plan to reduce TTFT when attention DP is enabled and concurrency is low by replacing the all-gather with a series of broadcasts wrapped in an NCCL group (that way there is no need to pad every rank to the same number of tokens). I think your idea might further improve performance.
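A minimal sketch of that unpadded gather, assuming each rank's token count is already known (e.g. exchanged with a small all-gather beforehand); issuing the broadcasts with async_op=True plays the role of the NCCL group, and allgather_unpadded is an illustrative name:

```python
import torch
import torch.distributed as dist


def allgather_unpadded(local_tokens: torch.Tensor, token_counts):
    """Gather variable-length token tensors from all attention-DP ranks without
    padding, using one broadcast per rank instead of a padded all-gather.

    token_counts[r] is the number of tokens held by rank r.
    """
    rank = dist.get_rank()
    hidden_size = local_tokens.shape[-1]

    buffers, works = [], []
    for r in range(dist.get_world_size()):
        buf = local_tokens if r == rank else torch.empty(
            token_counts[r], hidden_size,
            dtype=local_tokens.dtype, device=local_tokens.device)
        # async_op=True lets the backend issue the broadcasts back-to-back,
        # similar to wrapping them in a single NCCL group.
        works.append(dist.broadcast(buf, src=r, async_op=True))
        buffers.append(buf)

    for work in works:
        work.wait()
    return torch.cat(buffers, dim=0)
```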
/bot run
PR_Github #838 [ run ] triggered by Bot
PR_Github #838 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #676 completed with status: 'FAILURE'
/bot run
PR_Github #847 [ run ] triggered by Bot
/bot kill
PR_Github #853 [ kill ] triggered by Bot
PR_Github #847 [ run ] completed with state ABORTED
PR_Github #853 [ kill ] completed with state SUCCESS
Successfully killed previous jobs for commit 7015d2d
/bot run