Torch engine prefix caching

Open grimoire opened this issue 1 year ago • 19 comments

Enable by setting shared_cache=True.
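A minimal usage sketch (hypothetical: this assumes the flag is exposed through PytorchEngineConfig; the PR may only wire it into the internal CacheConfig, so check the diff for the actual entry point):

```python
from lmdeploy import pipeline, PytorchEngineConfig

# shared_cache is the flag named in this PR; whether it is accepted here
# or only on the internal pytorch CacheConfig is an assumption.
pipe = pipeline(
    'internlm/internlm2-chat-7b',
    backend_config=PytorchEngineConfig(shared_cache=True),
)

# Prompts sharing a long common prefix can now reuse cached KV blocks.
prefix = 'You are a helpful assistant. '
print(pipe([prefix + 'Hello!', prefix + 'Summarize LoRA in one line.']))
```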

grimoire avatar Apr 04 '24 04:04 grimoire

Hi @grimoire @lvhan028 Why did you choose a radix tree implementation? Have you considered a hash table implementation instead? What factors did you weigh, such as scalability or performance? Thanks.

zhyncs avatar Apr 07 '24 06:04 zhyncs

Any details about the hash table implementation? Honestly, I do not like my radix tree implementation in this PR.

grimoire avatar Apr 07 '24 06:04 grimoire

Any details about the hash table implementation? Honestly, I do not like my radix tree implementation in this PR.

@ispobock may follow up. We have been researching the implementations in vLLM, RTP-LLM, and SGLang.

zhyncs avatar Apr 07 '24 06:04 zhyncs

@grimoire We compared the prefix cache implementations of other projects (a simplified sketch of the hash-table lookup follows the list):

  • vllm

    • Hash Table
    • computes a hash key for each block: hash(prefix tokens, tokens in this block)
    • block-level reuse: if seq1 = xxxxyyyy, seq2 = xxxxzzzz, seq3 = xxxxyyyyzzzz, and each block contains 4 tokens, then seq2 can reuse the first block of seq1 and seq3 can reuse 2 blocks of seq1
    • currently supports only prefix caching ( xxxxxoooo ), but plans to support general caching (xxxoooxxxooo) in the future
    • may need to consider hash collisions
    • Complexity:
      • Assume N is the number of sequences and L is the length of each sequence
      • Time (Find & Insert): O(N*(L^2)), because computing the hash keys takes O(L^2), as mentioned here
      • Space: O(N*L)
  • rtp-llm

    • Hash Table
    • computes a hash key for each seq: hash(tokens in the sequence)
    • block-level reuse, like vllm
    • Complexity:
      • Time
        • Find: O((N^2)*L), due to token-level matching
        • Insert: O(N*L)
      • Space: O(N*L)
  • sglang

    • Radix Tree
    • supports only prefix caching ( xxxxxoooo ), not general caching (xxxoooxxxooo)
    • Complexity:
      • Time (Find & Insert): O(N*L)
      • Space: O(N*L) in the worst case, if there is no shared part
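To make the hash-table rows concrete, here is a simplified sketch of the vllm-style block-hash lookup (illustrative only, not vLLM source; full-prefix tuples stand in for hash values, which sidesteps collisions at the cost of memory and makes the O(L^2) key-construction cost visible):

```python
from typing import Dict, List, Tuple

BLOCK = 4  # tokens per block

def block_keys(tokens: List[int]) -> List[Tuple[int, ...]]:
    """One key per full block, covering ALL tokens up to the end of that
    block, so equal keys imply identical prefixes. Building every key
    touches O(L^2) tokens in total, hence the find/insert complexity."""
    return [tuple(tokens[:end]) for end in range(BLOCK, len(tokens) + 1, BLOCK)]

def match_prefix(cache: Dict[Tuple[int, ...], int],
                 tokens: List[int]) -> List[int]:
    """Return physical block ids for the longest cached block-aligned prefix."""
    blocks: List[int] = []
    for key in block_keys(tokens):
        if key not in cache:
            break
        blocks.append(cache[key])
    return blocks

seq1 = list(b'xxxxyyyy')
cache = {key: i for i, key in enumerate(block_keys(seq1))}   # insert seq1
print(match_prefix(cache, list(b'xxxxzzzz')))      # [0]    -> reuses 1 block
print(match_prefix(cache, list(b'xxxxyyyyzzzz')))  # [0, 1] -> reuses 2 blocks
```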

ispobock avatar Apr 07 '24 07:04 ispobock

When do we need a general cache?

grimoire avatar Apr 07 '24 08:04 grimoire

@ispobock Do they support window attention? How do they evict blocks? Would it take a long time if we have a large number of blocks?

S-LoRA would increase the number of blocks (by using a small block size), and window attention would make block eviction more complex. I have failed to find a good solution.

grimoire avatar Apr 07 '24 08:04 grimoire

@ispobock Do they support window attention? How do they evict blocks? Would it take a long time if we have a large number of blocks?

S-LoRA would increase the number of blocks (by using a small block size), and window attention would make block eviction more complex. I have failed to find a good solution.

In mistralai-sf24/hackathon, the sliding window has been removed: https://x.com/mistralailabs/status/1771670765521281370

zhyncs avatar Apr 07 '24 08:04 zhyncs

And I think this approach is acceptable for now. https://github.com/InternLM/lmdeploy/blob/137d106a0cba4f5a0297ac07959256c57435433b/lmdeploy/pytorch/config.py#L63-L66

zhyncs avatar Apr 07 '24 08:04 zhyncs

@grimoire

When do we need general cache?

For example, with seq1 = xxxxyyyyzzzz, seq2 = yyyyzzzz, and 4 tokens per block, a general cache would let seq2 reuse the last 2 cached blocks of seq1. It's mentioned in vllm's design, but I'm not sure about the real usage and implementation.

How do they evict blocks? Would it take a long time if we have a large amount of blocks?

It seems all of them use reference counting + LRU as the eviction policy.
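A hedged sketch of what reference-count + LRU eviction amounts to (illustrative; each project layers its own bookkeeping on top of this):

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Block:
    block_id: int
    ref_count: int = 0        # pinned while an active sequence uses it
    last_access: float = 0.0  # timestamp of the most recent cache hit

def evict(blocks: List[Block], need: int) -> List[Block]:
    """Pick up to `need` victims: only unreferenced blocks are evictable,
    least-recently-used first."""
    victims = sorted((b for b in blocks if b.ref_count == 0),
                     key=lambda b: b.last_access)
    return victims[:need]

pool = [Block(0, ref_count=1, last_access=3.0),  # in use -> never evicted
        Block(1, ref_count=0, last_access=1.0),
        Block(2, ref_count=0, last_access=2.0)]
print([b.block_id for b in evict(pool, 1)])      # [1]: oldest unreferenced
```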

ispobock avatar Apr 07 '24 08:04 ispobock

And I think this approach is acceptable for now.

https://github.com/InternLM/lmdeploy/blob/137d106a0cba4f5a0297ac07959256c57435433b/lmdeploy/pytorch/config.py#L63-L66

ref https://github.com/vllm-project/vllm/pull/2762/files#r1495331586

zhyncs avatar Apr 07 '24 08:04 zhyncs

Sure, let's ignore the sliding window for now.

It seems that a hash map does not bring much benefit to prefix matching. Eviction by blocks takes more time than eviction by node (sorting by visit time, updating ref-counts/visit-times, updating sequence status, ...).

But adding the new node concept into the scheduler made the code error-prone and hard to maintain. Any advice?

grimoire avatar Apr 07 '24 09:04 grimoire

vllm didn't take the radix tree implementation because of the maintenance burden:

Major benefits of this design over a KV block Trie

  • Sometimes, caching is not limited to prefix caching:
    • With Mistral's sliding window attention, we only need to cache the last tokens in the sliding window.
    • With attention sinks, we need to cache the first few tokens and the latest tokens.
  • Maintaining hash table is simpler than maintaining a tree.
  • Extensible to more advanced caching policy (the one above is just an example).

In sglang, there is actually no block concept because each page holds exactly one token, which simplifies the implementation.
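A minimal sketch of such a token-level radix tree (one token per page; illustrative only, not SGLang source). With no block granularity, match and insert are plain trie walks, and the per-node metadata is exactly what the reference-count + LRU eviction mentioned earlier needs:

```python
import time
from typing import List

class Node:
    def __init__(self):
        self.children = {}      # token id -> Node
        self.ref_count = 0      # pinned while an active request uses it
        self.last_access = 0.0  # for LRU eviction of unpinned nodes

def match_prefix(root: Node, tokens: List[int]) -> int:
    """Walk the tree; the depth reached is the reusable prefix length."""
    node, depth = root, 0
    for t in tokens:
        if t not in node.children:
            break
        node = node.children[t]
        node.last_access = time.monotonic()
        depth += 1
    return depth

def insert(root: Node, tokens: List[int]) -> None:
    node = root
    for t in tokens:
        node = node.children.setdefault(t, Node())

root = Node()
insert(root, list(b'xxxxyyyy'))
print(match_prefix(root, list(b'xxxxzzzz')))  # 4: token-level, not block-level
```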

ispobock avatar Apr 07 '24 11:04 ispobock

For example seq1: xxxxyyyyzzzz, seq2: yyyyzzzz, 4 tokens per block, for general cache, seq2 may use the last 2 cached blocks of seq1.

In this case

  1. The positional embedding used for yyyyzzzz is offset by 4 positions (instead of starting from 0).
  2. xxxx, which was involved in the computation of xxxxyyyyzzzz, is ignored.

The result will be different from computing yyyyzzzz directly. The outcome may be similar, but there is no guarantee.
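A toy numeric check of point 2 (a hedged sketch: single-head causal attention over random embeddings, with no projections or positional encodings, so the offset from point 1 would add a second source of divergence on top of this):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
embed = {c: rng.normal(size=d) for c in 'xyz'}  # toy token embeddings

def causal_attention(tokens: str) -> np.ndarray:
    """Single-head causal self-attention without projections, for brevity."""
    X = np.stack([embed[c] for c in tokens])              # (T, d)
    scores = X @ X.T / np.sqrt(d)                         # (T, T)
    mask = np.tril(np.ones((len(tokens), len(tokens))))
    scores = np.where(mask == 1, scores, -np.inf)         # causal mask
    w = np.exp(scores - scores.max(-1, keepdims=True))
    w /= w.sum(-1, keepdims=True)
    return w @ X                                          # (T, d) outputs

in_context = causal_attention('xxxxyyyyzzzz')[-8:]  # yyyyzzzz inside seq1
standalone = causal_attention('yyyyzzzz')           # yyyyzzzz on its own
print(np.abs(in_context - standalone).max())        # clearly nonzero
```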

lzhangzz avatar Apr 07 '24 12:04 lzhangzz

vllm didn't take the radix tree implementation because of the maintenance burden: ...

Hi @grimoire Do you have any suggestions?

zhyncs avatar Apr 08 '24 03:04 zhyncs

Maintaining hash table is simpler than maintaining a tree.

That's true, especially when the block size is not 1. In this PR, a node is a wrapper around a sequence with meta info. I wanted to share the same block-management code to ease the implementation, but it ... sucks.

I want to try the block-based strategy. I guess it will take a long time to design and prototype, since I don't want to break any existing features.

grimoire avatar Apr 08 '24 03:04 grimoire

Hi @grimoire I would like to know: is this PR currently ready for normal use? Thanks.

zhyncs avatar Apr 08 '24 06:04 zhyncs

@zhyncs Yes, this is not a draft.

grimoire avatar Apr 08 '24 07:04 grimoire

ref https://github.com/InternLM/lmdeploy/issues/1407#issuecomment-2044203407

zhyncs avatar Apr 09 '24 06:04 zhyncs

@grimoire We compared the prefix cache implementations of other projects: ... (see the full comparison earlier in the thread)

After https://github.com/sgl-project/sglang/pull/364, the RPS of SGLang's radix tree implementation increased by nearly 10%.

zhyncs avatar Apr 18 '24 08:04 zhyncs

Very good discussion here. ref https://github.com/vllm-project/vllm/issues/2614#issuecomment-2116330884

merrymercy avatar May 16 '24 22:05 merrymercy