vllm
[V1] LoRA - Add triton kernels for V1
Add shrink and expand triton kernels for V1.
Why do we need a new set of kernels:
- V0 sorts/groups requests by LoRA ID. The SGMV kernels take advantage of this and group the compute within thread blocks.
- V1 doesn't group requests by LoRA ID. Instead, the new kernels receive information about which input tokens map to which LoRA ID and use it to load the appropriate input tokens. The rest of the matmul is very similar to the SGMV kernels.
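To make the difference concrete, here is a minimal pure-Python sketch of building per-LoRA token metadata from an unsorted token-to-LoRA mapping. Only token *indices* are grouped; the input tokens themselves are never reordered, unlike the V0 approach of sorting requests so each LoRA forms a contiguous segment. The function and field names (`build_lora_token_meta`, `token_indices_sorted`, `lora_token_start_loc`, etc.) and the `-1` "no LoRA" sentinel are hypothetical and are not claimed to match the PR's code.

```python
from typing import List, NamedTuple


class LoRATokenMeta(NamedTuple):
    # Hypothetical container; field names are illustrative only.
    active_lora_ids: List[int]       # distinct LoRA IDs present in the batch
    num_tokens_per_lora: List[int]   # token count for each active LoRA
    lora_token_start_loc: List[int]  # offset of each LoRA's slice in token_indices_sorted
    token_indices_sorted: List[int]  # token positions grouped by LoRA ID


def build_lora_token_meta(token_lora_mapping: List[int]) -> LoRATokenMeta:
    """Group token indices by LoRA ID without moving the tokens.

    token_lora_mapping[i] is the LoRA ID used by token i; -1 is assumed to mean "no LoRA".
    """
    # Stable sort of positions by LoRA ID: tokens sharing an ID keep their original order.
    order = sorted(range(len(token_lora_mapping)), key=lambda i: token_lora_mapping[i])
    active: List[int] = []
    counts: List[int] = []
    starts: List[int] = []
    indices: List[int] = []
    for pos in order:
        lora_id = token_lora_mapping[pos]
        if lora_id < 0:
            continue  # assumed: tokens without a LoRA need no shrink/expand work
        if not active or active[-1] != lora_id:
            # First token of a new LoRA: open a new slice in `indices`.
            active.append(lora_id)
            counts.append(0)
            starts.append(len(indices))
        counts[-1] += 1
        indices.append(pos)
    return LoRATokenMeta(active, counts, starts, indices)
```

For the mapping `[2, 0, -1, 2, 0, 1]` this yields `active_lora_ids=[0, 1, 2]`, `num_tokens_per_lora=[2, 1, 2]`, `lora_token_start_loc=[0, 2, 3]` and `token_indices_sorted=[1, 4, 5, 0, 3]`, so work for LoRA 2 would read input rows 0 and 3 in place.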
Kernel Code Change: The new kernels reuse a lot of the code from the existing SGMV kernels. The main changes are:
- Kernel launch grid formulation (required so that the kernels are CUDA Graph compatible; note that the SGMV kernels are not).
- Loading of the input tokens (A matrix) for the matmul.
All other kernel code is the same as the existing SGMV kernels. I refactored the code so it can be reused.
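Both changes can be illustrated with a pure-Python simulation of a shrink launch (`out[t] = x[t] @ A[lora(t)]^T`). This is a hedged sketch under simplifying assumptions: the grid layout, the one-slot-per-LoRA-ID convention, the block sizes and all function names are hypothetical, and the hidden-dimension (K) loop and any split-K are folded into a plain dot product. The ideas it shows are that the grid is computed only from shapes and static limits such as `max_loras`, so it stays fixed across CUDA Graph replays, and that each program instance gathers its input rows by token index rather than assuming a contiguous, sorted segment.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def shrink_launch_grid(num_tokens: int, rank: int, max_loras: int,
                       block_m: int, block_n: int) -> Tuple[int, int]:
    """Hypothetical grid: axis 0 tiles (tokens x rank), axis 1 is one slot per LoRA.

    It depends only on shapes and static limits, never on which LoRAs are active,
    so a captured CUDA graph could be replayed with a different token-to-LoRA mapping.
    """
    # Size axis 0 for the worst case: every token belongs to the same LoRA.
    tiles = math.ceil(num_tokens / block_m) * math.ceil(rank / block_n)
    return (tiles, max_loras)


def shrink_reference(x: Matrix, lora_a_weights: List[Matrix],
                     token_lora_mapping: List[int],
                     block_m: int = 2, block_n: int = 2) -> Matrix:
    """Simulate every program instance of out[t] = x[t] @ A[lora(t)]^T.

    x is the matmul's "A" operand (input tokens, unsorted); lora_a_weights[i] is the
    LoRA A weight of shape (rank, hidden) for LoRA slot i. Mapping value -1 = no LoRA.
    """
    num_tokens, rank = len(x), len(lora_a_weights[0])
    out = [[0.0] * rank for _ in range(num_tokens)]
    tiles, max_loras = shrink_launch_grid(num_tokens, rank, len(lora_a_weights),
                                          block_m, block_n)
    n_tiles = math.ceil(rank / block_n)
    for slot in range(max_loras):  # plays the role of program_id(1)
        # Token indices for this LoRA, in their original order: x is never sorted.
        rows = [t for t, lid in enumerate(token_lora_mapping) if lid == slot]
        weight = lora_a_weights[slot]
        for tile in range(tiles):  # plays the role of program_id(0)
            pid_m, pid_n = divmod(tile, n_tiles)
            row_block = rows[pid_m * block_m:(pid_m + 1) * block_m]
            if not row_block:
                continue  # early exit: no tokens for this (slot, tile) pair
            for t in row_block:  # gather input rows by token index
                for r in range(pid_n * block_n, min((pid_n + 1) * block_n, rank)):
                    out[t][r] = sum(xv * wv for xv, wv in zip(x[t], weight[r]))
    return out
```

Because `shrink_launch_grid` never looks at `token_lora_mapping`, changing which tokens use which LoRA changes only the data each instance reads, not the number of instances launched.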
Benchmark serving numbers:
Server command: VLLM_USE_V1="1" vllm serve meta-llama/Llama-2-7b-hf --max-loras 4 --max-lora-rank 8 --enable-lora --lora-modules lora1=yard1/llama-2-7b-sql-lora-test lora2=yard1/llama-2-7b-sql-lora-test lora3=yard1/llama-2-7b-sql-lora-test lora4=yard1/llama-2-7b-sql-lora-test --no-enable-prefix-caching
Benchmark command: python3 benchmarks/benchmark_serving.py --model meta-llama/Llama-2-7b-hf --dataset-name sharegpt --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 500 --request-rate inf --lora-modules lora1 lora2 lora3 lora4
V1 LoRA - This PR:
============ Serving Benchmark Result ============
Successful requests: 500
Benchmark duration (s): 68.81
Total input tokens: 117316
Total generated tokens: 110942
Request throughput (req/s): 7.27
Output token throughput (tok/s): 1612.20
Total Token throughput (tok/s): 3317.02
---------------Time to First Token----------------
Mean TTFT (ms): 7103.84
Median TTFT (ms): 7136.82
P99 TTFT (ms): 14040.12
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 126.93
Median TPOT (ms): 94.96
P99 TPOT (ms): 231.31
---------------Inter-token Latency----------------
Mean ITL (ms): 76.47
Median ITL (ms): 58.02
P99 ITL (ms): 236.28
==================================================
V1 LoRA - Main:
============ Serving Benchmark Result ============
Successful requests: 500
Benchmark duration (s): 117.84
Total input tokens: 117316
Total generated tokens: 110942
Request throughput (req/s): 4.24
Output token throughput (tok/s): 941.44
Total Token throughput (tok/s): 1936.96
---------------Time to First Token----------------
Mean TTFT (ms): 10277.45
Median TTFT (ms): 9370.56
P99 TTFT (ms): 22882.99
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms): 259.87
Median TPOT (ms): 236.82
P99 TPOT (ms): 445.65
---------------Inter-token Latency----------------
Mean ITL (ms): 173.39
Median ITL (ms): 164.76
P99 ITL (ms): 459.73
==================================================
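The two runs can be compared directly. The small script below only recomputes ratios from the numbers reported above; treating throughput as higher-is-better and durations and latencies as lower-is-better is the only assumption.

```python
# Headline metrics copied from the two benchmark runs above (this PR vs. main).
this_pr = {"duration_s": 68.81, "req_per_s": 7.27, "output_tok_per_s": 1612.20,
           "mean_ttft_ms": 7103.84, "mean_tpot_ms": 126.93, "mean_itl_ms": 76.47}
main = {"duration_s": 117.84, "req_per_s": 4.24, "output_tok_per_s": 941.44,
        "mean_ttft_ms": 10277.45, "mean_tpot_ms": 259.87, "mean_itl_ms": 173.39}

HIGHER_IS_BETTER = {"req_per_s", "output_tok_per_s"}


def improvement(metric: str) -> float:
    """How many times better this PR is than main on one metric."""
    if metric in HIGHER_IS_BETTER:
        return this_pr[metric] / main[metric]
    return main[metric] / this_pr[metric]  # durations and latencies: lower is better


for name in this_pr:
    print(f"{name:>18}: {improvement(name):.2f}x")
```

With these numbers the PR is about 1.71x better on duration, request throughput and output token throughput, 1.45x on mean TTFT, 2.05x on mean TPOT and 2.27x on mean ITL.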
Kernel micro benchmark:
The kernel microbenchmarks are here: https://docs.google.com/spreadsheets/d/1b_8KsDGdiSGWlHODMszug_-do7OlSPlSzoV84VkmVPc/edit?usp=sharing (sheet: "V1 : Dont Sort Tokens By LoRA ")
Note: The V0 SGMV and BGMV kernels are not tuned, whereas the V1 kernels are tuned with the Triton autotuner. Part of the gap between the V1 and SGMV/BGMV kernels may therefore be due to tuning. The SGMV kernel's performance depends heavily on the input being sorted by LoRA ID; the V1 kernels are much less sensitive to input ordering.
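Since tuning may account for part of the gap, it helps to recall what an autotuner does: benchmark a list of candidate launch configurations for a given problem shape, keep the fastest, and cache that choice so the benchmarking cost is paid once per shape. In Triton this is done with `triton.autotune`, whose `key` argument names the kernel arguments whose values trigger re-tuning. The class below is a conceptual pure-Python sketch of that idea, not Triton's implementation; `ToyAutotuner`, the `BLOCK_M` candidates and the toy cost model are all hypothetical.

```python
import math
from typing import Callable, Dict, List, Tuple

Config = Dict[str, int]
Shape = Tuple[int, ...]


class ToyAutotuner:
    """Pick the fastest config per problem shape and cache the choice."""

    def __init__(self, configs: List[Config], bench: Callable[[Config, Shape], float]):
        self.configs = configs
        self.bench = bench  # returns a runtime estimate for a config on a shape
        self.cache: Dict[Shape, Config] = {}
        self.bench_calls = 0

    def best_config(self, shape: Shape) -> Config:
        if shape not in self.cache:  # benchmark only the first time a shape is seen
            timings = []
            for cfg in self.configs:
                self.bench_calls += 1
                timings.append((self.bench(cfg, shape), cfg))
            self.cache[shape] = min(timings, key=lambda t: t[0])[1]
        return self.cache[shape]


def toy_cost(cfg: Config, shape: Shape) -> float:
    """Hypothetical cost model: per-block overhead plus wasted padded rows."""
    m, bm = shape[0], cfg["BLOCK_M"]
    blocks = math.ceil(m / bm)
    return blocks * 1.0 + (blocks * bm - m) * 0.1


configs = [{"BLOCK_M": 16}, {"BLOCK_M": 64}, {"BLOCK_M": 128}]
tuner = ToyAutotuner(configs, toy_cost)
print(tuner.best_config((8, 16)))     # few tokens: small blocks waste less padding
print(tuner.best_config((4096, 16)))  # many tokens: large blocks mean fewer launches
```

The point of the sketch is that the best configuration depends on the problem shape, so an untuned kernel running a single fixed configuration can look slower than a tuned one even if the underlying algorithm is comparable.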
All the LoRA tests have failed again
Looking into this now 👍
Update: I enabled the tests in tests/lora/test_layers.py for V1. They pass locally but run out of memory (OOM) on CI; I am tracking this down.
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @varun-sundar-rabindranath.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
It seems these modifications have significantly increased the time consumption for lora testing
Yes. This PR adds the v1_kernel tests in test_punica_ops.py and also enables test_layers.py to run for V1. I believe most of the increase comes from test_layers.py, which now runs for both V1 and V0 (effectively doubling its run time). I'll see what we can do here.
[Edit] @jeejeelee
Update: I reduced the tests in commits https://github.com/vllm-project/vllm/pull/13096/commits/a18d273557b2bfebde85f1fc41699e64c87259a5 and https://github.com/vllm-project/vllm/pull/13096/commits/ba94947b83d597ba4091385b4feaa78b01c83a0d
The times are now,
That is an increase of at most 7 minutes. Do you think we should prune further?
