[V1] Support cross-layer KV sharing
Motivation
Some models like Tencent-Hunyuan-Large (#10043) and Hymba-1.5B-Base (#10783) use cross-layer KV sharing (e.g. Cross-Layer Attention). This PR adds the ability for KV caches to be shared between attention layers.
Testing
Sanity Check
As a sanity check that the implementation is working, I made all layers after the 18th layer in Qwen/Qwen3-8B (36 layers total) share the 18th layer's KV cache, and printed out the id() of the KV cache used in each attention forward:
model.layers.0.self_attn.attn => 139678446053136
model.layers.1.self_attn.attn => 139678446059136
…
model.layers.15.self_attn.attn => 139678446045456
model.layers.16.self_attn.attn => 139678446055056
model.layers.17.self_attn.attn => 139678446050736
model.layers.18.self_attn.attn => 139678446050736
model.layers.19.self_attn.attn => 139678446050736
…
model.layers.32.self_attn.attn => 139678446050736
model.layers.33.self_attn.attn => 139678446050736
model.layers.34.self_attn.attn => 139678446050736
model.layers.35.self_attn.attn => 139678446050736
As expected, layers 19 to 36 (model.layers.18 through model.layers.35) are re-using the KV cache allocated by layer 18 (model.layers.17).
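For illustration, here is a minimal, self-contained sketch in plain PyTorch (not the vLLM implementation; layer names and tensor shapes are made up to mirror the log above) of why the id() check comes out this way: layers past the cutoff simply hold a reference to the target layer's cache tensor instead of allocating their own.

```python
import torch

NUM_LAYERS = 36
SHARE_FROM = 17  # model.layers.17 (the 18th layer) is the last layer that allocates its own cache

# Allocate a dummy KV cache per layer; layers after SHARE_FROM reuse layer 17's.
kv_caches: dict[str, torch.Tensor] = {}
for i in range(NUM_LAYERS):
    name = f"model.layers.{i}.self_attn.attn"
    if i <= SHARE_FROM:
        kv_caches[name] = torch.empty(2, 16, 8, 64)  # this layer owns its cache
    else:
        target = f"model.layers.{SHARE_FROM}.self_attn.attn"
        kv_caches[name] = kv_caches[target]  # alias: no new allocation

for name, cache in kv_caches.items():
    print(name, "=>", id(cache))

# All layers from model.layers.18 onward print the same id() as model.layers.17,
# matching the output shown above.
assert id(kv_caches["model.layers.35.self_attn.attn"]) == \
       id(kv_caches["model.layers.17.self_attn.attn"])
```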
Unit Tests
All newly added unit tests pass:
pytest tests/v1/worker/test_gpu_model_runner.py -k "test_init_kv_cache"
Evals
I checked the GSM8K score before and after this PR on Qwen/Qwen3-8B:
lm_eval --model vllm --tasks gsm8k --model_args pretrained=Qwen/Qwen3-8B,tensor_parallel_size=1 --batch_size auto
Before PR:
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8795|± |0.0090|
| | |strict-match | 5|exact_match|↑ |0.8734|± |0.0092|
After PR:
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.8802|± |0.0089|
| | |strict-match | 5|exact_match|↑ |0.8734|± |0.0092|
also cc: @heheda12345
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @sarckk.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
The entrypoints test failure is unrelated; it is also failing on trunk (see https://buildkite.com/vllm/fastcheck/builds/24385).
@heheda12345 could you take a look?
@heheda12345 thanks for taking a look. To answer your questions:
- I prefer 'KV sharing' simply because it seems to be the academic term for this kind of thing (e.g. see https://arxiv.org/abs/2410.14442), whereas 'KV reuse' seems to be used to refer to something else (e.g. prefix caching, https://developer.nvidia.com/blog/introducing-new-kv-cache-reuse-optimizations-in-nvidia-tensorrt-llm/).
- > One model with kv sharing should use less memory per block than another model with the same model config but without kv sharing.

  I didn't quite understand why it would be "less memory per block". I think we'll just have fewer physical KV blocks being used? The core memory savings would come from not allocating a KV cache for a layer that has a KV sharing target layer. I might be missing some other implementation details here; let's chat offline?
- > Is KV sharing compatible with kv connectors now?

  Not at the moment, I believe.
- > To mimic it, we can only return layers with `kv_sharing_target_layer_idx` is None

  I explored this design, but I remember the complexity was just offloaded to a later stage, since we needed to handle KV allocation for layers without a KV cache spec anyway. The APIs around KV cache groups have changed considerably since then, though, so let me take a look again.
- Yes, this is a good point. Some models explicitly keep track of the FQN for each layer, so it shouldn't be difficult. I'll make this change.
- Yes, I will add this check.
@heheda12345 updated the PR with your feedback. could you take a look?
Overview of changes:

- Removed all references to "reuse" and unified on the term "sharing".
- Standardized on using the layer FQN (`kv_sharing_target_layer_name`) instead of the layer index (`kv_sharing_target_layer_idx`) to avoid ambiguities.
- As suggested, made KV sharing implicit by only returning layers without a KV sharing target layer in the `get_kv_cache_spec` method of the model runner (GPU and TPU).
- With this design, the logic of "we have less physical memory per KV block, thus we can increase num_gpu_blocks" is handled implicitly; I added comments explaining why cross-layer KV sharing lets us allocate more GPU blocks (see the back-of-the-envelope sketch after this list): https://github.com/vllm-project/vllm/blob/4ca72c70504c4f7018c15fcacece0308d84436ac/vllm/v1/core/kv_cache_utils.py#L624-L626
- Added target layer validation and a V1-only support check at the `Attention` layer level.
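To make the "more GPU blocks" point concrete, here is a back-of-the-envelope sketch. All numbers (memory budget, block size, head count, head dim, dtype) are assumptions for illustration only; the actual accounting lives in `kv_cache_utils.py` and handles more detail.

```python
# Illustrative arithmetic only; not the real vLLM block accounting.
GiB = 1024**3

kv_budget_bytes = 8 * GiB   # assumed memory set aside for the KV cache
block_size = 16             # tokens per KV block (assumed)
num_kv_heads = 8            # assumed
head_dim = 128              # assumed
dtype_bytes = 2             # fp16
# Per layer, each block stores K and V for block_size tokens.
bytes_per_layer_per_block = block_size * num_kv_heads * head_dim * 2 * dtype_bytes

total_layers = 36
layers_with_own_cache = 18  # the other 18 layers share an earlier layer's cache

def num_blocks(n_alloc_layers: int) -> int:
    # Each block only needs physical storage for layers that allocate a cache.
    return kv_budget_bytes // (n_alloc_layers * bytes_per_layer_per_block)

print("without sharing:", num_blocks(total_layers), "blocks")
print("with sharing:   ", num_blocks(layers_with_own_cache), "blocks")
# With half the layers allocating, each block costs half the memory,
# so roughly twice as many blocks fit in the same budget.
```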
I think the design is overall cleaner than before; thanks for the feedback. To answer your question:
> What is the blocker for making it compatible with KV connector?
I looked into this a bit more. I think the current design is actually compatible with KV connectors, but we need to keep the `maybe_save_kv_layer_to_connector` call in each attention forward, even for layers that don't allocate their own KV cache, because the KV connector isn't aware of shared caches. For example, the simple `SharedStorageConnector`, which saves/loads KV to/from disk, writes the KV cache for each layer to a separate safetensors file, so we still need to save every KV layer (a rough illustration follows below). LMCache's connector also seemed able to handle `kv_caches` with shared memory pointers. I'm not sure whether other connectors are compatible with cross-layer KV sharing, though. For now I think we can allow them to be used together, but let me know if there are any concerns.
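As a rough illustration of the per-layer saving point (the helper below is hypothetical and not the KV connector API), a connector that persists one artifact per layer still has to be invoked for every layer name, even when two names alias the same tensor:

```python
import os
import tempfile

import torch


def save_layer_to_connector(layer_name: str, kv: torch.Tensor, out_dir: str) -> None:
    # Hypothetical stand-in for a connector that writes one file per layer,
    # similar in spirit to a disk-backed connector.
    torch.save(kv, os.path.join(out_dir, layer_name.replace(".", "_") + ".pt"))


shared = torch.zeros(2, 16, 8, 64)
kv_caches = {
    "model.layers.17.self_attn.attn": shared,
    "model.layers.18.self_attn.attn": shared,  # shares layer 17's cache
}

out_dir = tempfile.mkdtemp()
# Even though layer 18 holds no cache of its own, we still save under its
# name so that a connector unaware of sharing finds every layer it expects.
for name, kv in kv_caches.items():
    save_layer_to_connector(name, kv, out_dir)

print(sorted(os.listdir(out_dir)))  # two files, one per layer name
```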
updated to address comments.
For KV connector, can you at least try https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/disaggregated-prefill-v1 with a local model?
Tried it and it still works.
@heheda12345 addressed comments. could you take another look?