vllm
[Kernel] Support merge attn cuda kernel
Motivation
This PR adds a CUDA kernel for merge_attn_states to mitigate the CPU overhead of the current Triton kernel.
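For context, merge_attn_states combines two partial attention results (for example the prefix and suffix parts in cascade attention) using the log-sum-exp (LSE) of the scores each part was normalized by. Below is a minimal pure-Python reference of that merge for a single token/head. It assumes the standard LSE-rescaling formulation. The name merge_attn_states_ref is hypothetical, and this is neither the Triton nor the CUDA kernel from this PR.

```python
import math


def merge_attn_states_ref(prefix_out, prefix_lse, suffix_out, suffix_lse):
    """Hypothetical reference merge of two partial attention outputs.

    prefix_out / suffix_out: per-part attention outputs (lists of length head_size)
    prefix_lse / suffix_lse: log-sum-exp of the scores each part was normalized by
    Returns (merged_out, merged_lse). Does not handle both LSEs being -inf.
    """
    # Subtract the larger LSE before exponentiating for numerical stability.
    max_lse = max(prefix_lse, suffix_lse)
    p_scale = math.exp(prefix_lse - max_lse)
    s_scale = math.exp(suffix_lse - max_lse)
    total = p_scale + s_scale
    # Reweight each partial output by its share of the total softmax mass.
    merged_out = [(p_scale * p + s_scale * s) / total
                  for p, s in zip(prefix_out, suffix_out)]
    # LSE of the union of both score sets.
    merged_lse = max_lse + math.log(total)
    return merged_out, merged_lse
```

Merging the outputs of attention over two disjoint halves of the keys this way reproduces attention over all keys, which is why the kernel only needs the two partial outputs plus their LSEs.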
Performance benefits
On my local RTX 3080, the CUDA kernel is ~5% faster than the current Triton merge kernel (77.554 us vs. 81.383 us, about 4.7% lower latency).
Baseline (Triton):
(vllmbuild) wenqin@wenqin-System-Product-Name:~/study/vllm$ python benchmarks/kernels/benchmark_cascade_attention.py --version="triton"
INFO 04-04 21:19:56 [__init__.py:239] Automatically detected platform cuda.
Namespace(version='triton', num_tokens=1000, num_query_heads=64, num_kv_heads=8, head_size=128, dtype='half', seed=0, profile=False)
Warming up...
Kernel running time: 81.383 us
Optimized (CUDA):
(vllmbuild) wenqin@wenqin-System-Product-Name:~/study/vllm$ python benchmarks/kernels/benchmark_cascade_attention.py --version="cuda"
INFO 04-04 21:20:08 [__init__.py:239] Automatically detected platform cuda.
Namespace(version='cuda', num_tokens=1000, num_query_heads=64, num_kv_heads=8, head_size=128, dtype='half', seed=0, profile=False)
Warming up...
Kernel running time: 77.554 us
Heuristic
The numbers above are for num_tokens=1000. When I set num_tokens=10000, the CUDA kernel appears to be ~5% slower than the Triton kernel.
Inspecting the SASS, the CUDA kernel contains some redundant computation instructions compared with the Triton kernel; I suspect these redundant instructions are the root cause of the regression.
WDYT about this regression? Should we dispatch to the Triton kernel when num_tokens exceeds a certain threshold?
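One way the proposed dispatch could look is a simple token-count cutoff. This is a minimal sketch: select_merge_backend and the 4096 default are hypothetical placeholders, not values measured in this PR. The benchmarks only show the CUDA kernel winning at num_tokens=1000 and losing at num_tokens=10000, so a real threshold would have to come from sweeping num_tokens in between.

```python
# Hypothetical cutoff; the measured crossover lies somewhere between
# 1000 and 10000 tokens and has not been pinned down.
DEFAULT_CUDA_MAX_TOKENS = 4096


def select_merge_backend(num_tokens: int,
                         cuda_max_tokens: int = DEFAULT_CUDA_MAX_TOKENS) -> str:
    """Hypothetical helper: choose which merge kernel to launch."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")
    # Smaller batches: use the CUDA kernel, which was faster at 1000 tokens.
    if num_tokens <= cuda_max_tokens:
        return "cuda"
    # Larger batches: fall back to Triton, which was faster at 10000 tokens.
    return "triton"
```

Keeping the threshold as a parameter makes it easy to tune per GPU, since the crossover point measured on an RTX 3080 may not hold on other hardware.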
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @ywq880611.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork