mlx-swift-examples
Feature: prompt caching (Fixes #310)
Fixes: https://github.com/ml-explore/mlx-swift-examples/issues/310
Currently there is no way to persist the cache between calls to generate().
This is a relatively simple fix: [KVCache] parameters are added to the generate() functions and passed through to the TokenIterator.
Trim functions have been added to the KVCache protocol and implemented in KVCacheSimple. While not strictly necessary for caching, trimming is useful because the new prompt is often partially inconsistent with the cache, either through tokenizer inconsistencies or because recent messages have been intentionally manipulated (e.g. removing a <think> block).
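To illustrate why trimming matters, here is a minimal sketch (not the code from this PR; `reconcile` is a hypothetical helper) of lining a new prompt up against previously cached tokens: trim back to the longest common prefix, then only the unseen suffix needs to be processed.

```swift
import MLXLMCommon

/// Hypothetical helper: trim the cache back to the longest common prefix of
/// the previously cached tokens and the new prompt, and return the suffix
/// that still needs to be prefilled / passed to generate().
func reconcile(cache: [KVCache], cachedTokens: [Int], newPrompt: [Int]) -> [Int] {
    // Longest common prefix between what the cache has seen and the new prompt.
    var common = 0
    while common < cachedTokens.count, common < newPrompt.count,
          cachedTokens[common] == newPrompt[common] {
        common += 1
    }
    // Drop the divergent tail from every layer (uses the trim added in this PR).
    let toTrim = cachedTokens.count - common
    if toTrim > 0 {
        for layer in cache { _ = layer.trim(n: toTrim) }
    }
    // Only the unseen part of the prompt needs to be processed.
    return Array(newPrompt[common...])
}
```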
An example of how to implement a prompt cache has been added to MLXChatExample. The time to first token (TTFT) is now also displayed, which makes the performance improvement from caching easy to see.
The prompt cache is implemented in the PromptCache class, which is marked @unchecked Sendable so that it can be used within the ModelContainer context. Currently there is no isolation on the KVCache in PromptCache.
Note that when using the AsyncStream versions of generate() there is no way to return token ids, so the newly generated response can't be added to the cache and will be reprocessed on the next message. Perhaps the token ids could be added to Generation.chunk?
Overall I like this direction. I think it needs:
- finish the borrowing of the KVCache (implement a visitor that holds a lock so we can satisfy the unchecked Sendable)
- remove the debug printing
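Something along these lines, as a rough sketch (the names are hypothetical, not from the PR): all access to the cache goes through a visitor that holds a lock, which is what backs up the @unchecked Sendable claim.

```swift
import Foundation
import MLXLMCommon

/// Sketch of a prompt cache whose `@unchecked Sendable` conformance is backed
/// by a lock: callers can only touch the KVCache through `withCache`.
final class LockedPromptCache: @unchecked Sendable {
    private let lock = NSLock()
    private var cache: [KVCache]
    private var tokens: [Int]

    init(cache: [KVCache], tokens: [Int] = []) {
        self.cache = cache
        self.tokens = tokens
    }

    /// Visitor that holds the lock for the duration of `body`.
    func withCache<R>(_ body: (inout [KVCache], inout [Int]) throws -> R) rethrows -> R {
        lock.lock()
        defer { lock.unlock() }
        return try body(&cache, &tokens)
    }
}
```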
What do you think of this with regard to #330? That encapsulates the KVCache which is part of this PR, but I think the cache manipulation is still key.
Yes, passing KVCache to generate() is now there, which is the critical part. It at least allows developers to add their own prompt cache handling.
Just a comment on the streamlined approach and the EvaluateLLM example, which defaults to Qwen3: Qwen specifically states that <think> blocks should not be included in the chat history. If they aren't included, then the KVCache needs to be trimmed, because it will still contain the most recent <think> block from the last generated response. So technically the example should have cache trimming, but I don't think it is an issue for EvaluateLLM. ChatSession should probably have it, though?
Just to clarify, if the <think> blocks are removed then the LLM will respond in the next turn with no knowledge of any prior <think> blocks.
If the KVCache isn't trimmed, then all of the prior <think> blocks will still be in it and the LLM will be aware of them.
So the Qwen3 thinking example may behave differently to examples that remove the <think> blocks. I haven't tested both ways to see how they respond differently.
Reference on the <think> tags: https://api-docs.deepseek.com/guides/reasoning_model
I think this is a little bit complicated, but goes something like this (a rough sketch in code follows the list):
- iterator keeps track of start index when doing generation
- produces full output
- goes back through tokens and filters out the <think> section(s)
- resets the KVCache to the start index
- prefills the KVCache with the response (like a prompt)
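Roughly, in code (a sketch only; `filterThink` and `prefill` are hypothetical helpers, and the exact tokenizer calls may differ):

```swift
import MLXLMCommon
import Tokenizers

/// Sketch of the replay step: rewind the cache to where generation started,
/// then feed the response back in with the <think> section(s) removed.
func replayWithoutThink(
    cache: [KVCache],
    startOffset: Int,                    // cache offset recorded before generation
    outputTokens: [Int],                 // full response, <think> included
    tokenizer: Tokenizer,
    filterThink: (String) -> String,     // hypothetical: strips <think>...</think>
    prefill: ([KVCache], [Int]) -> Void  // hypothetical: processes tokens like a prompt
) {
    // Re-encode the response with the <think> section(s) removed.
    let cleanText = filterThink(tokenizer.decode(tokens: outputTokens))
    let cleanTokens = tokenizer.encode(text: cleanText)

    // Rewind every layer back to the start index (assumes the cache exposes
    // its current offset, as KVCacheSimple does) ...
    for layer in cache { _ = layer.trim(n: layer.offset - startOffset) }
    // ... and replay the filtered response as if it were a new prompt.
    prefill(cache, cleanTokens)
}
```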
We would have to consider what happens if the caller terminates the iteration early. Maybe it isn't even part of the Iterator (ideally we could compose this).
Anyway, the <think> section will be in the KVCache as a consequence of generation so we have to replay the output to replace what is in there.
Something to consider if the <think> blocks are to be removed at the token level is that the closing </think> tag could form a token together with the following response. For example, if the response looked like:

```
<think>
Thinking... thinking...
</think>
Here is my response
```
It is possible that a token could be >Here, which wouldn't allow the <think> block to be properly removed at the token level.
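One way to detect that case, as a sketch (decoding token by token like this is approximate for some tokenizers, so treat it as illustrative only):

```swift
import Tokenizers

/// Returns true if "</think>" ends exactly on a token boundary, i.e. the block
/// could be removed cleanly at the token level; false if a token spans the tag
/// and the following text (e.g. ">Here").
func thinkTagEndsOnTokenBoundary(tokens: [Int], tokenizer: Tokenizer) -> Bool {
    var text = ""
    for token in tokens {
        text += tokenizer.decode(tokens: [token])
        if text.hasSuffix("</think>") { return true }    // clean boundary
        if text.contains("</think>") { return false }    // tag merged into a larger token
    }
    return false
}
```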
If the <think> block could be removed at the token level an even better optimisation would be to snip the <think> block out, not just trim. But I don't know enough about LLMs and KV caches to know if this would work.
If we can't edit the cache at the token level, then the entire previous response has to be removed (as it is guaranteed to be on a clean token boundary), and it is pre-filled again as if a new prompt.
Note that the KVCache trimming is extremely efficient; in the current implementation in the PR it just edits the offset:
```swift
public func trim(n: Int) -> Int {
    // Never trim more than what is actually in the cache.
    let toTrim = min(self.offset, n)
    // Moving the offset back is all that is needed; the underlying
    // key/value arrays stay in place and are simply overwritten later.
    self.offset -= toTrim
    return toTrim
}
```
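For example, rolling back the entire previous response is just a loop over the layers (sketch, with hypothetical names):

```swift
import MLXLMCommon

/// Sketch: discard the last `tokenCount` tokens from every layer. Because only
/// the offset moves, this is O(number of layers), not O(tokens).
func rollBack(cache: [KVCache], tokenCount: Int) {
    for layer in cache { _ = layer.trim(n: tokenCount) }
}
```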
Yeah, we probably have to handle the <think> tags in decoded token space, trim them out, then re-tokenize.
FWIW mlx-lm (python side) does not have this capability yet.