[V1][Core] Support offloading KV cache to CPU.
TL;DR
In V1, swap GPU KV cache blocks to CPU upon eviction and swap them back if there's a cache hit.
Swap Strategy
CPU → GPU swap-in happens naturally when requests hit the cache (unless we do prefetching).
GPU → CPU swap-out can be handled in two ways:
- Eagerly: Immediately after a request completes and its blocks are freed.
- Lazily: When evicting a GPU block while scheduling new requests.
This PR adopts the lazy approach to minimize unnecessary swaps. However, the downside is that the swap-out overhead may be exposed on the critical path.
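As a concrete illustration of the lazy path, here is a minimal, self-contained sketch of swap-out on eviction; all class and field names below (e.g. `LazySwapOutSketch`, `step_d2h_swap_map`) are made up for the example and are not the PR's actual code:

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, Dict, Optional


@dataclass
class Block:
    block_id: int
    block_hash: Optional[int] = None  # hash of the full block's tokens, if cached


@dataclass
class LazySwapOutSketch:
    free_gpu_blocks: Deque[Block]   # eviction-ordered free GPU blocks
    free_cpu_blocks: Deque[Block]   # CPU blocks available for offloading
    cached_block_hash_to_cpu_block: Dict[int, Block] = field(default_factory=dict)
    step_d2h_swap_map: Dict[int, int] = field(default_factory=dict)  # gpu id -> cpu id

    def evict_gpu_block(self) -> Block:
        """Reuse a free GPU block, lazily offloading its cached contents first."""
        gpu_block = self.free_gpu_blocks.popleft()
        if gpu_block.block_hash is not None and self.free_cpu_blocks:
            cpu_block = self.free_cpu_blocks.popleft()
            cpu_block.block_hash = gpu_block.block_hash
            self.cached_block_hash_to_cpu_block[cpu_block.block_hash] = cpu_block
            # Record the copy only; the model runner performs the actual
            # GPU -> CPU transfer in one aggregated call before execution.
            self.step_d2h_swap_map[gpu_block.block_id] = cpu_block.block_id
        gpu_block.block_hash = None
        return gpu_block


mgr = LazySwapOutSketch(
    free_gpu_blocks=deque([Block(0, block_hash=123)]),
    free_cpu_blocks=deque([Block(0)]),
)
reused = mgr.evict_gpu_block()  # records {0: 0} in step_d2h_swap_map
```

An eager policy would instead trigger this copy whenever a finished request frees its blocks, even for blocks that are never evicted, which is the extra work the lazy approach avoids.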
An ideal approach would asynchronously offload X cache blocks at a certain cadence (e.g., hidden behind the main CUDA graph) while maintaining free GPU block headroom. That would add complexity and is left for future work.
Implementation
This PR builds on the excellent V1 KV cache manager and blends in with the existing interface.
Newly introduced metadata states:
`cpu_block_pool` and `cached_block_hash_to_cpu_block` mirror their GPU counterparts.
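For illustration, a CPU cache hit during scheduling could turn into a swap-in decision roughly as follows; only the two metadata names above come from the PR, the rest of this sketch is assumed:

```python
from typing import Dict, Optional


def lookup_block(
    block_hash: int,
    cached_block_hash_to_block: Dict[int, int],      # GPU cache: hash -> gpu block id
    cached_block_hash_to_cpu_block: Dict[int, int],  # CPU cache: hash -> cpu block id
    step_h2d_swap_map: Dict[int, int],               # accumulated cpu id -> gpu id copies
    allocate_gpu_block,                              # callable returning a free GPU block id
) -> Optional[int]:
    """Return a GPU block id holding this content, or None on a full miss."""
    # Fast path: the block is still resident on the GPU.
    if block_hash in cached_block_hash_to_block:
        return cached_block_hash_to_block[block_hash]
    # CPU hit: reserve a GPU block and record a host-to-device copy for this step.
    if block_hash in cached_block_hash_to_cpu_block:
        cpu_block_id = cached_block_hash_to_cpu_block[block_hash]
        gpu_block_id = allocate_gpu_block()
        step_h2d_swap_map[cpu_block_id] = gpu_block_id
        cached_block_hash_to_block[block_hash] = gpu_block_id
        return gpu_block_id
    return None  # miss in both tiers: the tokens must be recomputed
```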
High-Level Flow:
- The KV cache manager accumulates swap-in/out decisions during each scheduling cycle.
- These swap decisions are then "flushed" to the scheduler output, allowing model runners to issue aggregated swap calls before model execution, minimizing dispatch overhead.
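To make the "aggregated swap calls" concrete, the model runner could batch all of a step's block copies into one transfer per direction, as in this rough PyTorch sketch (requires a CUDA device; the tensor layout and the `apply_swaps` helper are assumptions, not vLLM's actual kernels):

```python
import torch

# Toy sizes; real caches are per-layer tensors with much larger blocks.
num_gpu_blocks, num_cpu_blocks, block_numel = 16, 64, 4096
gpu_kv_cache = torch.zeros(num_gpu_blocks, block_numel, device="cuda")
cpu_kv_cache = torch.zeros(num_cpu_blocks, block_numel, pin_memory=True)


def apply_swaps(h2d_swap_map: dict, d2h_swap_map: dict) -> None:
    """Perform all of this step's block copies in one batched transfer per direction."""
    if d2h_swap_map:  # GPU -> CPU: offload blocks evicted during scheduling
        src = torch.tensor(list(d2h_swap_map.keys()), device="cuda")
        dst = torch.tensor(list(d2h_swap_map.values()))
        cpu_kv_cache[dst] = gpu_kv_cache[src].to("cpu")
    if h2d_swap_map:  # CPU -> GPU: restore blocks that hit in the CPU cache
        src = torch.tensor(list(h2d_swap_map.keys()))
        dst = torch.tensor(list(h2d_swap_map.values()), device="cuda")
        gpu_kv_cache[dst] = cpu_kv_cache[src].to("cuda", non_blocking=True)
```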
For simplicity, we avoid threading the scheduler output through multiple KV cache manager calls. Instead, swap-related data is accumulated in `step_*` fields (e.g., `step_h2d_swap_map`).
A new `end_schedule_step` callback resets them at the end of each scheduling iteration. (Open to alternative designs.)
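A minimal sketch of that per-step bookkeeping; only `step_h2d_swap_map` and `end_schedule_step` are taken from the PR text, while the surrounding class and the d2h counterpart are assumed:

```python
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class StepSwapState:
    """Per-step swap bookkeeping kept inside the KV cache manager."""
    # CPU block id -> GPU block id copies to run before this step's forward pass.
    step_h2d_swap_map: Dict[int, int] = field(default_factory=dict)
    # GPU block id -> CPU block id copies for blocks lazily offloaded on eviction.
    step_d2h_swap_map: Dict[int, int] = field(default_factory=dict)

    def flush_to_scheduler_output(self) -> Dict[str, Dict[int, int]]:
        """Snapshot the accumulated swaps so the model runner can apply them."""
        return {
            "h2d_swap_map": dict(self.step_h2d_swap_map),
            "d2h_swap_map": dict(self.step_d2h_swap_map),
        }

    def end_schedule_step(self) -> None:
        """Reset per-step state at the end of each scheduling iteration."""
        self.step_h2d_swap_map.clear()
        self.step_d2h_swap_map.clear()
```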
CPU Cache Eviction Policy
We currently adopt a simple round-robin strategy for CPU cache eviction; LRU will be added in a follow-up PR.
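For reference, round-robin eviction can be as simple as a cycling pointer over the CPU block pool; this sketch leaves out the free/used bookkeeping a real allocator needs:

```python
class RoundRobinCPUEvictor:
    """Cycle through CPU block ids and overwrite them in order."""

    def __init__(self, num_cpu_blocks: int) -> None:
        self.num_cpu_blocks = num_cpu_blocks
        self._next = 0

    def pick_victim(self) -> int:
        """Return the next CPU block id to overwrite, wrapping around the pool."""
        victim = self._next
        self._next = (self._next + 1) % self.num_cpu_blocks
        return victim
```

An LRU policy would instead evict the block whose cached content was least recently reused, at the cost of extra bookkeeping on every access.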
User Configuration:
We reuse the existing `--swap-space` flag (previously unused in V1) to control the number of CPU blocks.
Whether to change the default (currently 4GB) remains up for discussion.
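For intuition, the number of CPU blocks implied by `--swap-space` is roughly the swap space divided by the per-block KV footprint; the formula and model shapes below are illustrative and not necessarily the exact computation used in the PR:

```python
def num_cpu_blocks_from_swap_space(
    swap_space_gib: float,   # value passed to --swap-space, in GiB
    block_size_tokens: int,  # tokens per KV block (e.g. 16)
    num_layers: int,
    num_kv_heads: int,
    head_size: int,
    dtype_bytes: int = 2,    # fp16 / bf16
) -> int:
    # One block stores K and V for every layer and KV head.
    bytes_per_block = (
        2 * num_layers * num_kv_heads * head_size * block_size_tokens * dtype_bytes
    )
    return int(swap_space_gib * 1024**3) // bytes_per_block


# Llama-3.1-8B-like shapes (32 layers, 8 KV heads, head size 128) with a 4 GiB
# swap space -> 2 MiB per block -> 2048 CPU blocks.
print(num_cpu_blocks_from_swap_space(4, 16, 32, 8, 128))
```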
Benchmark Results
Performance depends on workload characteristics. We present two benchmark scenarios:
- Topline: High cache hits, aiming to measure speedup.
- Bottomline: Zero CPU cache hits (only GPU → CPU swaps), measuring swap-out overhead.
Topline Performance
```bash
# Baseline (No CPU Offloading): **14.2673 secs**
VLLM_USE_V1=1 python benchmark_long_document_qa_throughput.py --model meta-llama/Llama-3.1-8B --enable-prefix-caching --num-documents 10 --repeat-count 2 --repeat-mode tile --swap-space 0 --num-gpu-blocks-override 1252 --max-model-len 20010 --max-num-batched-tokens 20010

# With CPU Offloading: **8.0952 secs**
VLLM_USE_V1=1 python benchmark_long_document_qa_throughput.py --model meta-llama/Llama-3.1-8B --enable-prefix-caching --num-documents 10 --repeat-count 2 --repeat-mode tile --swap-space 40 --num-gpu-blocks-override 1252 --max-model-len 20010 --max-num-batched-tokens 20010
```
Bottomline Performance
```bash
# Baseline (No CPU Offloading): **14.2608 secs**
VLLM_USE_V1=1 python benchmark_long_document_qa_throughput.py --model meta-llama/Llama-3.1-8B --enable-prefix-caching --num-documents 10 --repeat-count 2 --repeat-mode tile --swap-space 0 --num-gpu-blocks-override 1252 --max-model-len 20010 --max-num-batched-tokens 20010

# With CPU Offloading: **14.2857 secs**
VLLM_USE_V1=1 python benchmark_long_document_qa_throughput.py --model meta-llama/Llama-3.1-8B --enable-prefix-caching --num-documents 10 --repeat-count 2 --repeat-mode tile --swap-space 4 --num-gpu-blocks-override 1252 --max-model-len 20010 --max-num-batched-tokens 20010
```
TODO
- [ ] write tests
- [ ] more benchmarks and profiling
- [ ] docs
Is V0 supported?
Hi @mengzhu28, thanks for submitting the great PR! I will reach out to you offline.
@mengzhu28 Could you please rebase the PR?
@WoosukKwon as discussed offline, created RFC #16144
Would it be better to abstract the CPU-offloading-related functions into a new class and add a parameter to enable it?
Hello, can this support a 1-CPU / n-GPU single-host setup?