Torch engine prefix caching
Enable by setting `shared_cache=True`.
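A minimal usage sketch, assuming the `shared_cache` flag from this PR is exposed through `PytorchEngineConfig`; where exactly the flag is surfaced may differ from this sketch, and the model name is only an example.

```python
from lmdeploy import pipeline, PytorchEngineConfig

# Hypothetical usage: enable prefix caching in the PyTorch engine.
# The `shared_cache` flag name comes from this PR's description; the exact
# config field it lands on may differ from this sketch.
backend_config = PytorchEngineConfig(shared_cache=True)

pipe = pipeline('internlm/internlm2-chat-7b', backend_config=backend_config)

# Requests that share a long prompt prefix can now reuse cached KV blocks.
print(pipe(['Write a haiku about caching.']))
```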
Hi @grimoire @lvhan028, why did you choose the radix tree implementation? Have you considered using a hash table implementation? What factors did you consider, such as scalability or performance? Thanks.
Any details about the hash table implementation? Honestly, I do not like my radix tree implementation in this PR.
@ispobock may follow up; he has been researching the implementations of vLLM, RTP-LLM, and SGLang.
@grimoire We compared the prefix cache implementations of other projects:

- vLLM - Hash Table
  - compute a hash key for each block: `hash(prefix tokens, tokens in this block)`
  - block-level reuse: if seq1 is `xxxxyyyy`, seq2 is `xxxxzzzz`, seq3 is `xxxxyyyyzzzz`, and each block contains 4 tokens, then seq2 can reuse the first block of seq1 and seq3 can reuse 2 blocks of seq1 (see the sketch after this list)
  - currently only supports prefix cache (`xxxxxoooo`), but plans to support general cache (`xxxoooxxxooo`) in the future
  - may need to consider hash collisions
  - Complexity (assume `N` is the number of seqs and `L` is the length of a seq):
    - Time (Find & Insert): `O(N*(L^2))`, because computing the hash keys needs `O(L^2)`, mentioned here
    - Space: `O(N*L)`
- RTP-LLM - Hash Table
  - compute a hash key for each seq: `hash(tokens in sequence)`
  - block-level reuse, like vLLM
  - Complexity:
    - Time:
      - Find: `O((N^2)*L)`, due to token-level match
      - Insert: `O(N*L)`
    - Space: `O(N*L)`
- SGLang - Radix Tree
  - can only support prefix cache (`xxxxxoooo`), cannot support general cache (`xxxoooxxxooo`)
  - Complexity:
    - Time (Find & Insert): `O(N*L)`
    - Space: worst case `O(N*L)`, if there is no shared part
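A minimal sketch of the block-level hash-table scheme described above: each block's key hashes the full token prefix plus the tokens in that block, so `xxxxzzzz` reuses only the first block of `xxxxyyyy`. All names here are illustrative; this is not any project's actual code.

```python
from typing import Dict, List

BLOCK_SIZE = 4

def block_keys(tokens: List[str]) -> List[int]:
    """Hash key per full block: hash(prefix tokens + tokens in this block)."""
    keys = []
    for start in range(0, len(tokens) - len(tokens) % BLOCK_SIZE, BLOCK_SIZE):
        prefix_and_block = tuple(tokens[:start + BLOCK_SIZE])
        keys.append(hash(prefix_and_block))
    return keys

# cache: block hash -> physical block id
cache: Dict[int, int] = {}
next_block = 0

def allocate(tokens: List[str]) -> List[int]:
    """Return physical block ids, reusing cached blocks on prefix hits."""
    global next_block
    blocks = []
    for key in block_keys(tokens):
        if key not in cache:
            cache[key] = next_block
            next_block += 1
        blocks.append(cache[key])
    return blocks

seq1 = list('xxxxyyyy')
seq2 = list('xxxxzzzz')
seq3 = list('xxxxyyyyzzzz')
print(allocate(seq1))  # [0, 1]
print(allocate(seq2))  # [0, 2]    -> first block of seq1 reused
print(allocate(seq3))  # [0, 1, 3] -> first two blocks of seq1 reused
```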
When do we need general cache?
@ispobock Do they support window attention? How do they evict blocks? Would it take a long time if we have a large number of blocks?
S-LoRA would increase the number of blocks (by using a small block size), and window attention would make block eviction more complex. I failed to find a good solution.
In mistralai-sf24/hackathon, the sliding window has been removed: https://x.com/mistralailabs/status/1771670765521281370
And I think this approach is acceptable for now. https://github.com/InternLM/lmdeploy/blob/137d106a0cba4f5a0297ac07959256c57435433b/lmdeploy/pytorch/config.py#L63-L66
@grimoire
> When do we need general cache?

For example seq1: `xxxxyyyyzzzz`, seq2: `yyyyzzzz`, 4 tokens per block; with general cache, seq2 may use the last 2 cached blocks of seq1.
It's mentioned in vLLM's design, but I'm not sure about its real usage and implementation.

> How do they evict blocks? Would it take a long time if we have a large amount of blocks?

It seems all of them use reference count + LRU as the eviction policy.

> And I think this approach is acceptable for now.
> https://github.com/InternLM/lmdeploy/blob/137d106a0cba4f5a0297ac07959256c57435433b/lmdeploy/pytorch/config.py#L63-L66
ref https://github.com/vllm-project/vllm/pull/2762/files#r1495331586
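A minimal sketch of the reference count + LRU eviction policy mentioned above; the structure is assumed for illustration and is not any project's actual code. Blocks with a positive ref count are pinned; among free blocks, the least recently used one is evicted first.

```python
import time
from dataclasses import dataclass, field
from typing import Dict, Optional

@dataclass
class CachedBlock:
    block_id: int
    ref_count: int = 0
    last_access: float = field(default_factory=time.monotonic)

class BlockCache:
    def __init__(self):
        self.blocks: Dict[int, CachedBlock] = {}

    def acquire(self, block_id: int) -> CachedBlock:
        # A sequence starts using this block: bump ref count and access time.
        block = self.blocks.setdefault(block_id, CachedBlock(block_id))
        block.ref_count += 1
        block.last_access = time.monotonic()
        return block

    def release(self, block_id: int) -> None:
        # The sequence finished; the block stays cached but becomes evictable.
        self.blocks[block_id].ref_count -= 1

    def evict_one(self) -> Optional[int]:
        # Evict the least recently used block that no sequence references.
        candidates = [b for b in self.blocks.values() if b.ref_count == 0]
        if not candidates:
            return None
        victim = min(candidates, key=lambda b: b.last_access)
        del self.blocks[victim.block_id]
        return victim.block_id
```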
Sure, let's ignore the sliding window for now.
It seems the hash map does not bring much benefit to prefix matching. Eviction by blocks takes more time than eviction by nodes (sort by visit time, update ref-count/visit-time, update sequence status, ...), but adding a new `node` concept into the scheduler made the code error-prone and hard to maintain.
Any advice?
vLLM didn't take the radix tree implementation due to the maintenance difficulty:

> Major benefits of this design over a KV block Trie
>
> - Sometimes, caching is not limited to prefix caching:
>   - With Mistral's sliding window attention, we only need to cache the last tokens in the sliding window.
>   - With attention sinks, we need to cache the first few tokens and the latest tokens.
> - Maintaining hash table is simpler than maintaining a tree.
> - Extensible to more advanced caching policy (the one above is just an example).
In SGLang, there is actually no `block` concept because the size of each page is equivalent to one token, which simplifies the implementation.
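A minimal sketch of a token-level prefix match in the SGLang spirit: with page size 1 there is no block rounding, so the lookup simply walks the tree token by token. For brevity this uses an uncompressed trie (a radix tree would merge chains of single children); all names are illustrative, not SGLang's actual code.

```python
from typing import Dict, List, Optional

class TrieNode:
    def __init__(self):
        self.children: Dict[int, 'TrieNode'] = {}
        self.kv_slot: Optional[int] = None  # cached KV entry index for this token

class TokenTrie:
    """Token-level prefix cache: one node per token, page size == 1."""

    def __init__(self):
        self.root = TrieNode()

    def match(self, tokens: List[int]) -> List[int]:
        # Walk down from the root and collect KV slots of the longest cached prefix.
        node, slots = self.root, []
        for tok in tokens:
            if tok not in node.children:
                break
            node = node.children[tok]
            slots.append(node.kv_slot)
        return slots

    def insert(self, tokens: List[int], kv_slots: List[int]) -> None:
        # Add the sequence, reusing existing nodes for the shared prefix.
        node = self.root
        for tok, slot in zip(tokens, kv_slots):
            child = node.children.setdefault(tok, TrieNode())
            if child.kv_slot is None:
                child.kv_slot = slot
            node = child

trie = TokenTrie()
trie.insert([1, 2, 3, 4], kv_slots=[0, 1, 2, 3])
print(trie.match([1, 2, 3, 9]))  # [0, 1, 2] -> prefix of length 3 reused
```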
> For example seq1: `xxxxyyyyzzzz`, seq2: `yyyyzzzz`, 4 tokens per block; with general cache, seq2 may use the last 2 cached blocks of seq1.

In this case:

- The positional embedding used for `yyyyzzzz` is offset by 4 steps (instead of starting from 0).
- `xxxx`, which is involved in the computation of `xxxxyyyyzzzz`, is ignored.

The result will be different from computing `yyyyzzzz` directly. The outcome may be similar, but there is no guarantee.
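A small numeric sketch of the positional point above, assuming rotary position embeddings: the K/V of `yyyyzzzz` cached inside `xxxxyyyyzzzz` were computed at positions 4..11, while a standalone `yyyyzzzz` needs positions 0..7, so the cached entries are not interchangeable (on top of the missing attention to `xxxx`). Dimensions and values are illustrative only.

```python
import numpy as np

def rope_angles(positions, dim=8, base=10000.0):
    # Rotation angles applied to each (even, odd) channel pair at the given positions.
    inv_freq = base ** (-np.arange(0, dim, 2) / dim)
    return np.outer(positions, inv_freq)

# "yyyyzzzz" inside "xxxxyyyyzzzz": token positions 4..11
with_prefix = rope_angles(np.arange(4, 12))
# "yyyyzzzz" on its own: token positions 0..7
standalone = rope_angles(np.arange(0, 8))

# Same tokens, different rotation angles -> different keys/queries.
print(np.allclose(with_prefix, standalone))  # False
```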
Hi @grimoire Do you have any suggestions?
> Maintaining hash table is simpler than maintaining a tree.
That's true, especially when the block size is not 1. In this PR, a node is a wrapper around a sequence with meta info. I wanted to share the same block-management code to ease the implementation, but it... sucks.
I want to try the block-based strategy. I guess it would take a long time to design and prototype, since I don't want to break any existing features.
Hi @grimoire, is this PR currently ready for normal use? Thanks.
@zhyncs Yes, this is not a draft.
ref https://github.com/InternLM/lmdeploy/issues/1407#issuecomment-2044203407
After https://github.com/sgl-project/sglang/pull/364, the RPS of the SGLang radix tree implementation increased by nearly 10%.
Very good discussion here. ref https://github.com/vllm-project/vllm/issues/2614#issuecomment-2116330884