mlx-examples

Handle longer prompt/generation

Open • awni opened this issue 1 year ago • 4 comments

This PR introduces two optimizations to handle much longer prompts and generations:

Longer prompts

  • Splits the prompt into segments during prompt processing (cache prefilling); a segment size of 512 tokens is about optimal.
  • Adds a state method to the KV cache classes so we can mx.eval only the cache state.
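A minimal sketch of the stepped prefill loop, written in plain Python so it runs without MLX. The names stepped_prefill and process_chunk are hypothetical: process_chunk stands in for a model forward pass that writes the chunk into the KV cache, after which the real code would mx.eval only the cache state. With the 4232-token prompt from the benchmark and a step of 512, the prompt is processed as eight 512-token segments plus a 136-token tail.

```python
from typing import Callable, List, Sequence


def stepped_prefill(
    prompt: Sequence[int],
    process_chunk: Callable[[Sequence[int]], None],
    step_size: int = 512,
) -> List[int]:
    """Process the prompt in fixed-size segments instead of all at once.

    `process_chunk` is a hypothetical stand-in for a model forward pass that
    appends the chunk's keys/values to the KV cache. In MLX one would also
    evaluate just the cache state after each chunk, so the lazy graph for
    the whole prompt is never built in one go. Returns the chunk sizes.
    """
    sizes = []
    for start in range(0, len(prompt), step_size):
        chunk = prompt[start:start + step_size]
        process_chunk(chunk)  # e.g. model(chunk, cache), then eval the cache state
        sizes.append(len(chunk))
    return sizes


# Toy "cache": a list recording every token the model has processed.
cache: List[int] = []
sizes = stepped_prefill(list(range(4232)), cache.extend, step_size=512)
print(sizes)  # [512, 512, 512, 512, 512, 512, 512, 512, 136]
```

Only one segment's activations are alive at a time, which is where the lower peak memory comes from.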

With stepped prefill:

Prompt: 4232 tokens, 3138.508 tokens-per-sec
Peak memory: 2.640 GB

Without stepped prefill:

Prompt: 4232 tokens, 2300.630 tokens-per-sec
Peak memory: 8.422 GB
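The relative changes implied by the two runs above, as a quick worked calculation (the figures are copied from the logs; nothing else is assumed):

```python
# Figures copied from the two benchmark runs above.
prompt_tokens = 4232
stepped = {"tps": 3138.508, "peak_gb": 2.640}
unstepped = {"tps": 2300.630, "peak_gb": 8.422}

speedup = stepped["tps"] / unstepped["tps"]
memory_reduction = unstepped["peak_gb"] / stepped["peak_gb"]
time_stepped = prompt_tokens / stepped["tps"]      # seconds to process the prompt
time_unstepped = prompt_tokens / unstepped["tps"]

print(f"throughput: {speedup:.2f}x")                # throughput: 1.36x
print(f"peak memory: {memory_reduction:.2f}x lower")  # peak memory: 3.19x lower
print(f"prompt time: {time_stepped:.2f} s vs {time_unstepped:.2f} s")  # prompt time: 1.35 s vs 1.84 s
```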

Longer generations

Adds a rotating KV cache to allow infinite generations. Toggle it with the flag --max-kv-size=512 (or any other value).

This is similar (probably identical) to the behavior in llama.cpp. The first keep tokens (default 4) of the prompt are always kept; everything else gets overwritten in the circular buffer. Based on this paper.
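Logically, a rotating cache of size max_size ends up holding the first keep tokens plus the most recent max_size - keep tokens. A small sketch of which token positions survive after num_tokens tokens have been seen; retained_positions is a hypothetical helper for illustration, not part of mlx-examples:

```python
from typing import List


def retained_positions(num_tokens: int, max_size: int, keep: int = 4) -> List[int]:
    """Token positions (in temporal order) held by a rotating KV cache of
    size max_size after num_tokens tokens: the first `keep` tokens plus the
    most recent max_size - keep tokens. Hypothetical helper for illustration."""
    if num_tokens <= max_size:
        return list(range(num_tokens))  # nothing has been overwritten yet
    recent_start = num_tokens - (max_size - keep)
    return list(range(keep)) + list(range(recent_start, num_tokens))


print(retained_positions(10, max_size=6, keep=2))        # [0, 1, 6, 7, 8, 9]
print(len(retained_positions(1_000_000, max_size=512)))  # 512
```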

With this technique you can generate indefinitely at fixed memory use.

awni, Aug 12 '24 18:08

looks great!

Jckwind, Aug 12 '24 19:08

Moving back to draft for a bit. The rotating buffer doesn't play well with the stepped prefill for long prompts, so that needs some work.

awni, Aug 12 '24 23:08

This is awesome!

Great work @awni 🔥

Blaizzy, Aug 13 '24 23:08

OK, so I think this can be reviewed and merged. A short note on the "infinite KV cache":

For simplicity, it separates the cache growth into two stages: prefill (i.e. prompt processing) and generation. During generation, we assume the cache is updated one time step at a time. During prefill, an update can contain any number of time steps.

The invariant is that every new token attends to at least max_size - 1 previous tokens (including the first keep=4 tokens).

During prefill, to maintain this invariant the KV cache can grow as large as max_size + step_size - 1 (where step_size is the prefill update step size). To keep things simple, we don't use a circular buffer during this stage, since the masking can get a bit complicated and the code is not well set up for that. Instead, during prefill the cache simply grows: the old cache is trimmed and the update is concatenated as a suffix to create the new cache.
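The prefill rule can be sketched on plain Python lists, where each element stands for one token's key/value entry. prefill_update is a hypothetical name, and the sketch assumes keep < max_size and that the cache is still in temporal order (true during prefill, since no circular writes have happened yet). Trimming the old cache to max_size - 1 entries before appending keeps max_size - 1 predecessors (or all of them, if fewer exist) visible to the first token of each update, and it bounds the cache at max_size + step_size - 1 entries.

```python
from typing import List


def prefill_update(cache: List[int], update: List[int], max_size: int, keep: int = 4) -> List[int]:
    """Hypothetical prefill step: trim the old cache to at most max_size - 1
    entries (always keeping the first `keep`), then append `update`."""
    trim = len(cache) - (max_size - 1)
    if trim > 0:
        # Drop the `trim` oldest entries that come after the kept prefix.
        cache = cache[:keep] + cache[keep + trim:]
    return cache + update


# Simulate a 2000-token prompt, step_size=512, max_size=1024, keep=4.
cache: List[int] = []
largest = 0
for start in range(0, 2000, 512):
    update = list(range(start, min(start + 512, 2000)))
    cache = prefill_update(cache, update, max_size=1024)
    largest = max(largest, len(cache))

print(largest)               # 1535 (= max_size + step_size - 1)
print(cache[:5], cache[-1])  # [0, 1, 2, 3, 517] 1999
```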

During generation, it uses a circular buffer with a fixed size of max_size and maintains an index to the next slot to write in the buffer.
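A list-based sketch of the generation-stage buffer, assuming one token is written per step. RotatingBuffer and its fields are hypothetical stand-ins for the KV cache, not the mlx-examples classes: slots below keep are never overwritten, and idx is the next slot to write, wrapping back to keep when it reaches max_size. If prefill left more than max_size entries, the constructor first trims them in temporal order (a simplification made for this sketch). The usage lines reproduce the [1, 2, 3, 4, 5, 6] example from the question in the next comment.

```python
from typing import List


class RotatingBuffer:
    """Hypothetical fixed-size circular buffer for the generation stage."""

    def __init__(self, initial: List[int], max_size: int, keep: int = 4):
        if len(initial) > max_size:
            # Keep the first `keep` entries and the most recent max_size - keep.
            initial = initial[:keep] + initial[len(initial) - (max_size - keep):]
        self.slots = list(initial)
        self.max_size = max_size
        self.keep = keep
        self.idx = len(self.slots)  # next slot to write

    def append(self, item: int) -> None:
        if self.idx == self.max_size:
            self.idx = self.keep  # wrap around, skipping the kept slots
        if self.idx == len(self.slots):
            self.slots.append(item)  # buffer not full yet: grow
        else:
            self.slots[self.idx] = item  # overwrite the oldest rotating entry
        self.idx += 1


buf = RotatingBuffer([1, 2, 3, 4, 5, 6], max_size=6, keep=2)
buf.append(7)
print(buf.slots)  # [1, 2, 7, 4, 5, 6]
buf.append(8)
print(buf.slots)  # [1, 2, 7, 8, 5, 6]
```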

awni, Aug 13 '24 23:08

Thanks for the great work. I have a quick question: does the circular buffer maintain the logical order of the cache? Please correct me if I'm wrong, but it seems the code does not maintain it. For example, if we start with a cache of [1, 2, 3, 4, 5, 6] and keep 1 and 2 as attention sinks, then after generating token 7 the cache becomes [1, 2, 7, 4, 5, 6]. Self-attention then uses [1, 2, 7, 4, 5, 6] instead of [1, 2, 4, 5, 6, 7].

mzbac, Aug 17 '24 05:08

Your understanding is exactly right. The logical order doesn't matter in this case: the output is the same because attention is invariant to permuting the cached key/value pairs (as long as keys and values are permuted together). (Note that RoPE is applied to the keys before they are put into the cache, so the position encodings are still valid.)
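A quick numerical check of the permutation argument in plain Python: single-query softmax attention gives the same output whether the cached keys/values are in logical order [1, 2, 4, 5, 6, 7] or in the buffer's physical order [1, 2, 7, 4, 5, 6]. The key and value vectors are arbitrary made-up functions of the token index, standing in for already position-encoded (post-RoPE) keys; they are not what a real model produces.

```python
import math
from typing import List, Sequence, Tuple


def attend(query: Sequence[float], keys: List[List[float]], values: List[List[float]]) -> List[float]:
    """Single-query scaled dot-product attention over cached keys/values."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [
        sum(w * v[d] for w, v in zip(weights, values)) / total
        for d in range(len(values[0]))
    ]


def gather(order: List[int]) -> Tuple[List[List[float]], List[List[float]]]:
    """Made-up key/value vectors per token index, listed in the given order.
    The keys stand in for already position-encoded (post-RoPE) keys."""
    keys = [[math.sin(t), math.cos(t), t / 10] for t in order]
    values = [[float(t), float(t * t)] for t in order]
    return keys, values


query = [0.3, -1.2, 0.5]
out_logical = attend(query, *gather([1, 2, 4, 5, 6, 7]))
out_physical = attend(query, *gather([1, 2, 7, 4, 5, 6]))
print(out_logical)
print(out_physical)  # equal to out_logical up to floating-point rounding
```

The check assumes no order-dependent mask over the cached slots: during one-token-at-a-time generation the new query attends to every cached entry, so slot order only changes the summation order.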

awni, Aug 17 '24 05:08

Thanks for the detailed explanation, really appreciate it. It makes a lot of sense to me now ❤️

mzbac, Aug 17 '24 05:08