add causal-conv1d in Triton and integrate into vLLM with test code
This PR adds a Triton-based causal-conv1d implementation, which makes Mamba-based models in vLLM
- able to run on a fully Triton-only backend, and
- one step closer to compatibility with the vLLM v1 design, i.e. without splitting the batch into prefill-only and decode-only requests for CUDA-split processing.
There are two kernels implemented (a reference sketch of the computation they carry out follows this list):
- causal_conv1d_update_triton: outperforms the corresponding CUDA kernel when handling decode-only requests.
[data benchmarking the runtime of the two kernels as the number of decode-only requests in a batch increases]
- causal_conv1d_fn_triton: outperforms the CUDA kernel on batches of mixed prefill/decode requests, e.g. 27x faster in the microbenchmark below on the same batch of mixed prefill/decode requests. It also performs better than the CUDA-split pathway that was merged as PR #17146.
[data benchmarking the runtime of processing the same batch of mixed requests, first sending the batch to the single Triton kernel, then using the CUDA-split pathway, where requests are first separated, prefill-only requests are sent to one kernel, and decode-only requests are sent to the other]
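For orientation, here is a minimal PyTorch sketch of the semantics both kernels implement: a causal depthwise conv1d over full prompts (the causal_conv1d_fn_triton case) and a single-token update against a rolling conv state (the causal_conv1d_update_triton case). The function names and signatures are illustrative, not the PR's actual API; the real kernels additionally handle bias, activation, and variable-length batches.

```python
import torch
import torch.nn.functional as F


def causal_conv1d_prefill_ref(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Prefill reference: x is (batch, dim, seqlen), weight is (dim, kernel_width).

    Each channel is convolved causally with its own filter (depthwise conv),
    so output[t] depends only on inputs at positions <= t.
    """
    dim, width = weight.shape
    x_padded = F.pad(x, (width - 1, 0))              # left-pad for causality
    return F.conv1d(x_padded, weight.unsqueeze(1), groups=dim)


def causal_conv1d_decode_ref(x_t: torch.Tensor, conv_state: torch.Tensor,
                             weight: torch.Tensor) -> torch.Tensor:
    """Decode reference: x_t is (batch, dim) for the new token; conv_state is
    (batch, dim, kernel_width) holding the most recent inputs per channel.

    The state window is shifted left by one position, the new token is
    appended, and the filter is applied to the updated window.
    """
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, :, -1] = x_t
    return (conv_state * weight.unsqueeze(0)).sum(dim=-1)   # (batch, dim)
```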
ALGORITHMIC CHOICE OF THE TRITON KERNEL: Unlike the CUDA kernel, which parallelizes in 2D (along the feature dimension and the batch), the Triton kernel parallelizes in 3D, adding the sequence dimension; a sketch of this launch grid is given below. The Triton kernels also make no changes to the layout of the input data, which is contiguous along the feature dimension. Another key difference is the conv-state cache layout: the Triton kernels expect the conv state to be contiguous along the feature dimension, whereas the existing CUDA implementation expects it to be contiguous along the kernel-width (i.e. sequence-length) axis. Moreover, the two CUDA kernels are not mutually compatible in the conv-state cache layout they expect, which prevents efficient processing of decode-only or mixed prefill/decode batches.
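To make the 3D-parallelism point concrete, below is a minimal Triton sketch of a causal depthwise conv1d that launches one program per (batch, feature-block, sequence-block) tile. It is an illustration under simplified assumptions (fixed-length batch, no bias, no activation, no conv-state handling) and is not the PR's kernel.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _causal_conv1d_3d_sketch(
    x_ptr, w_ptr, y_ptr,
    dim, seqlen,
    stride_b, stride_d, stride_t,       # strides shared by x and y (same layout)
    stride_wd, stride_wk,
    KERNEL_W: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    # 3D launch grid: axis 0 = batch, axis 1 = feature blocks, axis 2 = sequence blocks.
    pid_b = tl.program_id(0)
    pid_d = tl.program_id(1)
    pid_t = tl.program_id(2)

    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    offs_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
    mask_d = offs_d < dim
    mask_t = offs_t < seqlen

    acc = tl.zeros((BLOCK_D, BLOCK_T), dtype=tl.float32)
    for k in range(KERNEL_W):
        # Causality: the output at time t reads the input at time t - (KERNEL_W - 1) + k.
        t_in = offs_t - (KERNEL_W - 1) + k
        mask_in = mask_t & (t_in >= 0)
        x = tl.load(
            x_ptr + pid_b * stride_b
            + offs_d[:, None] * stride_d
            + t_in[None, :] * stride_t,
            mask=mask_d[:, None] & mask_in[None, :],
            other=0.0,
        )
        w = tl.load(w_ptr + offs_d * stride_wd + k * stride_wk,
                    mask=mask_d, other=0.0)
        wf = w.to(tl.float32)
        acc += x.to(tl.float32) * wf[:, None]

    tl.store(
        y_ptr + pid_b * stride_b
        + offs_d[:, None] * stride_d
        + offs_t[None, :] * stride_t,
        acc,
        mask=mask_d[:, None] & mask_t[None, :],
    )


def causal_conv1d_3d_launch(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # x: (batch, dim, seqlen) float32, weight: (dim, kernel_width) float32.
    x = x.contiguous()
    batch, dim, seqlen = x.shape
    y = torch.empty_like(x)
    BLOCK_D, BLOCK_T = 64, 64
    grid = (batch, triton.cdiv(dim, BLOCK_D), triton.cdiv(seqlen, BLOCK_T))
    _causal_conv1d_3d_sketch[grid](
        x, weight, y, dim, seqlen,
        x.stride(0), x.stride(1), x.stride(2),
        weight.stride(0), weight.stride(1),
        KERNEL_W=weight.shape[1], BLOCK_D=BLOCK_D, BLOCK_T=BLOCK_T,
    )
    return y
```

The PR's actual kernels additionally have to respect per-request boundaries in a variable-length batch and read/write the conv-state cache, which this sketch omits.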
Some other overhead-reducing improvements are incorporated as well. Even though the binary code generated by Triton is fast, kernel-launch overhead is a known issue and needs further optimization before E2E Triton-only Mamba models in vLLM are performant. Here, we address it by building metadata once and reusing it across layers, as sketched below.
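As an illustration of the metadata-reuse idea, the per-batch quantities every layer needs can be computed once per forward pass and handed to each layer's conv1d call instead of being rebuilt per layer. The names and fields below are hypothetical, not the PR's actual data structures:

```python
from dataclasses import dataclass

import torch


@dataclass
class ConvMetadata:
    """Hypothetical per-batch metadata, built once and reused by every Mamba layer.

    Rebuilding these tensors (and the derived launch grid) inside each layer adds
    CPU-side overhead on top of every Triton kernel launch.
    """
    cu_seqlens: torch.Tensor    # cumulative sequence lengths, shape (num_seqs + 1,)
    num_prefills: int
    num_decodes: int
    max_seqlen: int


def build_conv_metadata(seq_lens: list[int], num_prefills: int,
                        device: str = "cuda") -> ConvMetadata:
    """Build the metadata once per scheduler step, before the model forward pass."""
    lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
    cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32, device=device)
    cu_seqlens[1:] = torch.cumsum(lens, dim=0)
    return ConvMetadata(
        cu_seqlens=cu_seqlens,
        num_prefills=num_prefills,
        num_decodes=len(seq_lens) - num_prefills,
        max_seqlen=int(lens.max()),
    )


# Sketch of the call pattern: the same metadata object is passed to all layers,
# so each layer's conv1d launch does no per-layer host-side recomputation.
# metadata = build_conv_metadata(seq_lens, num_prefills)
# for layer in mamba_layers:
#     hidden = layer(hidden, conv_metadata=metadata)
```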
In our benchmark on the ShareGPT dataset, which has short input prompts (a few hundred tokens):
- default setting (a short number of generated tokens, i.e. 256): CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; the Triton path is 10% slower in total token throughput, but only 2% slower in output token throughput and 2% in TTFT.
- generating 1024 tokens: Triton-backed Bamba is now faster, with 5% higher token throughput and 11% better TTFT. The benefit of the faster Triton kernels now exceeds the overall cost of the Triton launch overhead.
At longer context lengths and/or larger numbers of generated tokens, the Triton-only Mamba-based models are expected to outperform the CUDA-split approach. However, the PR keeps the existing CUDA pathway as the default until the Triton pathway is adopted by the vLLM maintainers. Currently, the code is added as an optional alternative to the CUDA-split pathway, enabled by setting the VLLM_USE_TRITON_CONV1D environment variable to 1.
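For example, the Triton pathway can be opted into from Python before vLLM starts its workers; the snippet below is an illustrative offline-inference example (the model and prompt are arbitrary). Exporting VLLM_USE_TRITON_CONV1D=1 in the shell before launching a server has the same effect, as done in the lm_eval commands further down.

```python
import os

# Opt in to the Triton causal-conv1d pathway; the CUDA-split pathway remains
# the default when this variable is unset.
os.environ["VLLM_USE_TRITON_CONV1D"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="ibm-ai-platform/Bamba-9B")
outputs = llm.generate(["Mamba-based models use a causal conv1d because"],
                       SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```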
Test code is also added
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @thoangtrvn.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
Could you report some lm_eval scores running gsm8k, and also make sure it runs correctly without --enforce-eager?
GSM8K RESULT
# ibm-ai-platform/Bamba-9B
# (current) CUDA-SPLIT code
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.2335|± |0.0117|
| | |strict-match | 5|exact_match|↑ |0.3442|± |0.0131|
# PR code
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.2456|± |0.0119|
| | |strict-match | 5|exact_match|↑ |0.3495|± |0.0131|
# Zyphra/Zamba2-2.7B
# (current) CUDA-SPLIT code
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.5330|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.5466|± |0.0137|
# PR code
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.5330|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.5466|± |0.0137|
# mistralai/Mamba-Codestral-7B-v0.1
# (current) CUDA-SPLIT code
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.4647|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.4549|± |0.0137|
# PR code
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.4655|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.4526|± |0.0137|
COMMANDS TO RUN:
echo 'ibm-ai-platform/Bamba-9B'
lm_eval --model vllm --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"
export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"
echo 'Zyphra/Zamba2-2.7B'
lm_eval --model vllm --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"
export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"
echo 'Mamba-Codestral-7B-v0.1'
lm_eval --model vllm --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"
export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"
> default setting: generates short number of tokens (i.e. 256 tokens) CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; 10% slower (total token throughput) yet only 2% in output token throughput and 2% in TTFT.

The table shows a drop from 3628 to 3557 in total tok/s, which is only a 2% slowdown. Am I missing something?
For simplicity, it would be better to remove the causal_conv1d CUDA kernel and use the Triton kernel from this PR exclusively, if the slowdowns aren't significant. If it's a 2% drop in the worst case, I think that may be a reasonable thing to do.
@tlrmchlsmth : I merged the recent changes from main and adopted the contribution from PR #19327. Please let me know if there is anything else I should revise. Removing mamba2_metadata.py completely and using only mamba_attn.py is, I think, better done in a separate PR.
basic-correctness-test errors are related
fixed that circular import issue.
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @thoangtrvn.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
@tlrmchlsmth : after merging main, the CI/CD fails at two checks, which seem to be timeouts. Can you help check whether there is anything I should do?
@thoangtrvn I started seeing this error coming out of the CPU benchmark https://github.com/pytorch/pytorch-integration-testing/actions/runs/16216384228/job/45786696314?pr=44#step:15:3350. Any thoughts?
@huydhn : Not sure why it reports tl.int32 is not supported. I am checking now.
File "/opt/venv/lib/python3.12/site-packages/vllm/model_executor/layers/mamba/ops/causal_conv1d.py", line 30, in <module>
batch: tl.int32, # actually padded_batch
^^^^^^^^
AttributeError: module 'triton.language' has no attribute 'int32'. Did you mean: 'int64'?
@huydhn : Please try again when this PR is merged https://github.com/vllm-project/vllm/pull/20838
Thank you! It's fixed now