v1: Add Request.block_hashes
This PR moves the request block hashes from the KVCacheManager to the Request object itself. In particular, this allows connectors to access the request block hashes.
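A minimal sketch of the idea, assuming simplified stand-in classes (the real vLLM `Request` carries many more fields, and the connector lookup shown is illustrative, not the actual connector API):

```python
# Illustrative sketch only: simplified stand-ins for vLLM's Request and a
# connector-side lookup. Names and fields here are assumptions.
from dataclasses import dataclass, field


@dataclass
class Request:
    request_id: str
    prompt_token_ids: list[int]
    # With this PR, the request carries its own block hashes, so components
    # other than the KVCacheManager (e.g. connectors) can read them.
    block_hashes: list[bytes] = field(default_factory=list)


def count_external_hits(request: Request, external_index: set[bytes]) -> int:
    # A connector can now match the request's full blocks against an
    # external KV store without reaching into the KVCacheManager.
    hits = 0
    for block_hash in request.block_hashes:
        if block_hash not in external_index:
            break
        hits += 1
    return hits
```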
👋 Hi! Thank you for contributing to the vLLM project.
💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only fastcheck CI runs, a small and essential subset of CI tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.
Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
To run CI, PR reviewers can either add the `ready` label to the PR or enable auto-merge.
🚀
This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @orozery.
https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork
My major concern is that the `block_hashes` list is updated here:
https://github.com/vllm-project/vllm/blob/d3f05c9248d79dc900c79d090db16cc2e5d96ee3/vllm/v1/core/block_pool.py#L177
This is quite far from the `Request` class. It makes the number of elements in the list look like magic to connectors and hurts the modularity of the code. Therefore, if you want to put it into the `Request` class, can you move this logic from "update when a block hash is needed" to "update when the token ids of the request are updated"?
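A minimal sketch of that direction, assuming a chained prefix hash over fixed-size full blocks (the helper name, block size, and hash function here are illustrative, not vLLM's actual implementation):

```python
import hashlib

BLOCK_SIZE = 16  # assumed block size


def update_block_hashes(token_ids: list[int], block_hashes: list[bytes]) -> None:
    # Extend block_hashes with one chained hash per newly completed full
    # block, so len(block_hashes) == len(token_ids) // BLOCK_SIZE always
    # holds right after the request's token ids are updated.
    parent = block_hashes[-1] if block_hashes else b""
    while (len(block_hashes) + 1) * BLOCK_SIZE <= len(token_ids):
        start = len(block_hashes) * BLOCK_SIZE
        block = token_ids[start:start + BLOCK_SIZE]
        parent = hashlib.sha256(parent + repr(block).encode()).digest()
        block_hashes.append(parent)
```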
Today, block hashes are computed when the scheduler calls `_update_waiting_for_remote_kv`.
Shortly after this call, calls are made to `kv_cache_manager.get_computed_blocks` and `connector.get_num_new_matched_tokens`.
The connector must have the block hashes at this point.
I'm trying to interpret your suggestion "update when token ids of the request are updated".
My best guess is that you mean `request.append_output_token_ids`. That is too late for the connector.
I suspect I'm not understanding you correctly. Can you please be more concrete and reference actual code locations?
Yes, I mean `request.append_output_token_ids`.
Based on the code here, the new tokens are updated in `update_from_output` in step k and then used by the connector in `scheduler.schedule()` of step k+1. Why is it too late?
https://github.com/vllm-project/vllm/blob/72d14d0eed4b29e5827519283c085a7a674f3256/vllm/v1/engine/core.py#L223-L240
@heheda12345 Consider the following scenario:

- The engine gets a fresh new prompt with 1000 tokens (the exact number does not matter).
- The scheduler pops the prompt from its `self.waiting` list.
- The scheduler checks whether this prompt already has computed tokens in the GPU prefix cache by calling `self.kv_cache_manager.get_computed_blocks`. Assume no hits are found, so `num_computed_tokens = 0`.
- Next, it queries the connector for any externally hit tokens by calling `self.connector.get_num_new_matched_tokens(request, ...)`. At this point, the connector needs the block hashes in order to check for hits.
- With your approach of delaying the setting of `request.block_hashes`, the connector will yield `num_external_computed_tokens = 0`, even though it may have all the tokens available!
- The request tokens will be sent to the workers, and the workers will recompute the KV values for these 1000 tokens.
- Only after the workers recompute the tokens will the scheduler call `update_from_output`, setting the block hashes.

To sum up: because of the delay in setting `request.block_hashes`, we lose the ability to utilize offloaded KV values of tokens from an external source.
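A condensed sketch of the lookup sequence described above (method names follow the discussion; the wiring is illustrative, not vLLM's actual scheduler code):

```python
def num_reusable_tokens(scheduler, request) -> int:
    # GPU prefix-cache lookup for already-computed blocks.
    _, num_computed_tokens = scheduler.kv_cache_manager.get_computed_blocks(
        request)

    # External lookup. If request.block_hashes has not been filled in yet
    # (e.g. because hashing is deferred to update_from_output), the
    # connector has nothing to match against and must return 0, so all
    # prompt tokens get recomputed by the workers.
    num_external_computed_tokens = (
        scheduler.connector.get_num_new_matched_tokens(
            request, num_computed_tokens))

    return num_computed_tokens + num_external_computed_tokens
```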
But I think the hash of the prompt can be initialized when the request is created.
@heheda12345 The block hashes are computed inside `BlockPool.cache_full_blocks`.
The `Request` object is created in `EngineCore.add_request`.
Are you suggesting that we compute the hashes inside `EngineCore.add_request`?
Yes, you can compute the hash in both `add_request` and `append_output_token_ids`. Though I prefer the solution in https://github.com/vllm-project/vllm/pull/15652 for hashing prompt tokens, I'm OK with computing the hash in `add_request` in this PR and moving it to the frontend later.
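A sketch of that wiring, assuming the `update_block_hashes` helper from the earlier sketch and a request that exposes an `all_token_ids` list (both assumptions; the real `EngineCore` methods do much more than this):

```python
class EngineCoreSketch:
    # Illustrative wiring only, not vLLM's actual EngineCore.

    def add_request(self, request) -> None:
        # Hash the prompt's full blocks as soon as the request is created,
        # so the very first schedule() pass can already match them.
        update_block_hashes(request.all_token_ids, request.block_hashes)

    def _append_output_tokens(self, request, new_token_ids) -> None:
        # Called from update_from_output: keep the hash list in sync as
        # decode tokens complete new full blocks.
        request.all_token_ids.extend(new_token_ids)
        update_block_hashes(request.all_token_ids, request.block_hashes)
```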
> Yes, you can compute the hash in both `add_request` and `append_output_token_ids`.
@heheda12345 I made the changes as you suggested.
@heheda12345 @njhill Moved all hashing to be triggered by Request directly.
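A self-contained sketch of hashing triggered by the Request itself. The name `get_request_block_hasher` comes up in the review just below, but the body here is an assumption for illustration, not the PR's actual code:

```python
import hashlib
from typing import Callable

BLOCK_SIZE = 16  # assumed block size


def get_request_block_hasher(
        block_size: int) -> Callable[[list[int], list[bytes]], None]:
    # Returns a callback the Request invokes whenever its token ids grow.
    def hash_new_full_blocks(token_ids: list[int],
                             block_hashes: list[bytes]) -> None:
        parent = block_hashes[-1] if block_hashes else b""
        while (len(block_hashes) + 1) * block_size <= len(token_ids):
            start = len(block_hashes) * block_size
            block = token_ids[start:start + block_size]
            parent = hashlib.sha256(parent + repr(block).encode()).digest()
            block_hashes.append(parent)
    return hash_new_full_blocks


class RequestSketch:
    def __init__(self, prompt_token_ids: list[int]):
        self.all_token_ids = list(prompt_token_ids)
        self.block_hashes: list[bytes] = []
        self._block_hasher = get_request_block_hasher(BLOCK_SIZE)
        self._block_hasher(self.all_token_ids, self.block_hashes)  # prompt

    def append_output_token_ids(self, token_ids: list[int]) -> None:
        self.all_token_ids.extend(token_ids)
        self._block_hasher(self.all_token_ids, self.block_hashes)  # decode
```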
Can you add more comments in `hash_request_tokens`?
You mean in `get_request_block_hasher`?
I meant: move some of the comments from `hash_request_tokens` and `cache_full_blocks` into your new code.
@heheda12345 I made another set of changes. Thanks for taking the time to give such helpful comments!
@orozery could you rebase on latest main? Should hopefully help with at least some of the CI failures.
done
Can you rebase with main again? It should fix the `test_eagle_correctness` failure.