Handle longer prompt/generation
This PR introduces two optimizations to handle much longer prompts and generations:
Longer prompts
- During prompt processing / cache prefilling, split the prompt into segments (512 is about optimal); see the sketch after this list.
- Adds a `state` method to the KV cache classes so we can `mx.eval` only the cache state
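A minimal sketch of the stepped prefill loop, assuming a model whose `__call__` takes a `cache` argument and cache objects exposing the new `state` (the function name and exact API here are illustrative):

```python
import mlx.core as mx

def stepped_prefill(model, prompt, cache, step_size=512):
    # Process the prompt in fixed-size chunks. Evaluating only the
    # cache state between chunks lets intermediate activations be
    # freed instead of materializing one huge graph.
    for i in range(0, prompt.size, step_size):
        chunk = prompt[i : i + step_size]
        model(chunk[None], cache=cache)  # prefill logits are discarded
        mx.eval([c.state for c in cache])
```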
With stepped prefill:
- Prompt: 4232 tokens, 3138.508 tokens-per-sec
- Peak memory: 2.640 GB

Without stepped prefill:
- Prompt: 4232 tokens, 2300.630 tokens-per-sec
- Peak memory: 8.422 GB
Longer generations
Allow a rotating KV cache to enable infinite generations. Toggle with the flag `--max-kv-size=512` (or whatever value).
This is similar (probably identical) to the behavior in llama.cpp. The first `keep` tokens of the prompt (default 4) are always kept; everything else gets overwritten in the circular buffer. Based on this paper.
With this technique you can generate indefinitely at fixed memory use.
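For example (the model path is a placeholder, and the other flags are the usual `mlx_lm.generate` options):

```shell
python -m mlx_lm.generate --model <path-or-hf-repo> \
    --prompt "Write a long story." --max-tokens 100000 --max-kv-size 512
```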
looks great!
Back to draft for a few. The rotating buffer doesn't play well with the stepped prefill for long prompts, so that needs some work.
This is awesome!
Great work @awni 🔥
Ok so I think this can be reviewed and merged. A little note on the "infinite KV cache":
For simplicity it separates the cache growth into two stages: prefill (i.e. prompt processing) and generation. During generation we assume the updates to the cache are one time-step at a time. During the prefill stage, they can be any number.
The invariant is that every new token attends to at least `max_size - 1` previous tokens (including the first `keep=4`).
During prefill, to make this happen the KV cache can grow as large as `max_size + step_size - 1` (where `step_size` is the prefill update step size). To keep things simple we don't use a circular buffer during this stage, since the masking can get a bit complicated and the code is not well set up for that. Instead, the prefill stage simply grows by trimming the old cache and concatenating the update as a suffix to create the new cache.
During generation it uses a circular buffer with a fixed size of `max_size` and maintains an index into the next slot to write into the buffer. Both update paths are sketched below.
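A minimal sketch of that two-stage update, assuming `(B, n_heads, T, head_dim)` key/value arrays; the class name and edge-case handling are illustrative, not the exact implementation in this PR:

```python
import mlx.core as mx

class RotatingKVCacheSketch:
    def __init__(self, max_size, keep=4):
        self.max_size = max_size  # fixed cache budget during generation
        self.keep = keep          # attention-sink tokens, never evicted
        self.keys = None          # (B, n_heads, T, head_dim)
        self.values = None
        self._idx = 0             # next slot to overwrite when rotating

    def _trim(self, n_recent):
        # Drop the oldest non-sink entries, preserving the first `keep`
        # tokens and the `n_recent` most recent ones.
        self.keys = mx.concatenate(
            [self.keys[:, :, : self.keep], self.keys[:, :, -n_recent:]], axis=2
        )
        self.values = mx.concatenate(
            [self.values[:, :, : self.keep], self.values[:, :, -n_recent:]], axis=2
        )

    def update(self, k, v):
        full = self.keys is not None and self.keys.shape[2] >= self.max_size
        if k.shape[2] == 1 and full:
            # Generation with a full cache: rotate in place.
            if self.keys.shape[2] > self.max_size:
                # Settle any prefill overshoot down to the fixed budget.
                self._trim(self.max_size - self.keep)
                self._idx = self.keep
            if self._idx >= self.max_size:
                self._idx = self.keep  # wrap, skipping the sink tokens
            self.keys[:, :, self._idx : self._idx + 1] = k
            self.values[:, :, self._idx : self._idx + 1] = v
            self._idx += 1
        else:
            # Prefill (or a not-yet-full cache): trim to max_size - 1
            # entries, then append, so the cache peaks at
            # max_size + step_size - 1 and every new token still attends
            # to at least max_size - 1 previous tokens.
            if self.keys is None:
                self.keys, self.values = k, v
            else:
                if self.keys.shape[2] > self.max_size - 1 and k.shape[2] > 1:
                    self._trim(self.max_size - 1 - self.keep)
                self.keys = mx.concatenate([self.keys, k], axis=2)
                self.values = mx.concatenate([self.values, v], axis=2)
            self._idx = self.keys.shape[2]
        return self.keys, self.values
```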
Thanks for the great work, I have a quick question: will the circular buffer maintain the logical order of the cache? Please correct me if I'm wrong, but it seems the code does not maintain it. For example, if we start with an initial cache of [1, 2, 3, 4, 5, 6] and keep 1 and 2 as the attention sink, then after generating token 7 the cache becomes [1, 2, 7, 4, 5, 6]. So self-attention is using [1, 2, 7, 4, 5, 6] instead of [1, 2, 4, 5, 6, 7]...
Your understanding is exactly right. The logical order doesn't matter in this case; the output is the same since self-attention is invariant to permutations of its input. (Note the RoPE addition is done before the keys/values get put into the cache, so the position encodings are still valid.)
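A quick illustrative check of that invariance with toy shapes (not the model's actual attention code):

```python
import mlx.core as mx

d = 8
q = mx.random.normal((1, d))   # a single query
k = mx.random.normal((6, d))   # six cached keys
v = mx.random.normal((6, d))   # six cached values

def attend(q, k, v):
    scores = (q @ k.T) / mx.sqrt(mx.array(float(d)))
    return mx.softmax(scores, axis=-1) @ v

# Permute the keys and values together, e.g. swap slots 2 and 5.
perm = mx.array([0, 1, 5, 3, 4, 2])
print(mx.allclose(attend(q, k, v), attend(q, k[perm], v[perm])))
# -> array(True, dtype=bool)
```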
Thanks for the detailed explanation, really appreciate it. It makes a lot of sense to me now ❤️