
More cache improvements

Open · awni opened this issue 1 year ago

Sorry for the large diff. There's a lot of boilerplate / moving stuff around, which accounts for most of it.

The main bits are:

  • Fix RotatingKVCache for the alternating prompt/response (chat) use case
  • Enable prompt caching for all cache types (not just KVCache); a usage sketch follows below.
  • Unify APIs and cache types in a single file for ease of use / consistency.
  • Chat mode allows prompt caching for efficiency. Example here.
  • Add a bunch of tests.

Closes #1000
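
For concreteness, here is a minimal sketch of the prompt-cache reuse this enables, assuming the make_prompt_cache helper and the prompt_cache argument to generate added in this diff (exact names and import paths may differ):

from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# One prompt cache shared across turns: each generation extends it, so
# earlier context is never recomputed.
prompt_cache = make_prompt_cache(model)

for prompt in ["Hi, my name is Awni!", "Do you remember my name?"]:
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(generate(model, tokenizer, prompt=text, prompt_cache=prompt_cache))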

awni avatar Oct 05 '24 21:10 awni

I also added a chat command to MLX LM, which is a good use case for prompt-cache reuse. The example is kind of fun to play with:

mlx_lm.chat

Then you can just chat with the model; it preserves the history and doesn't do any prompt recomputation.

[INFO] Starting chat session with mlx-community/Llama-3.2-3B-Instruct-4bit. To exit, enter 'q'.
>> Hi, my name is Awni!
Hi Awni! It's nice to meet you. Is there something I can help you with or would you like to chat?
>> What's the tallest mountain in the world?
The tallest mountain in the world is Mount Everest, which is located in the Himalayas on the border between Nepal and Tibet, China. It stands at an elevation of 8,848 meters (29,029 feet) above sea level.
>> Do you remember my name?
Yes, your name is Awni.
>> Nice talking with you!
It's great to chat with you too, Awni! Is there anything else you'd like to talk about or ask about?
>> 

awni avatar Oct 07 '24 04:10 awni

> I am wondering what the point of the extra state in the KV cache is. Is anybody using it now? Is there any reason it is set to the empty string instead of None?

It's my least favorite thing in this diff, but I haven't thought of a cleaner solution yet (if you have ideas, I'm all 👂).

It is used only for the RotatingKVCache so that we can save the cache.offset and cache._idx. Otherwise we don't get the right/same behavior when serializing and deserializing that cache.

The reason I made it a string and not None is that it simplifies saving it in safetensors metadata (which only holds string values). So downstream code just does something like dict(tree_flatten([c.state[1] for c in cache])).
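
As a rough illustration (not the code from the diff) of why a string is convenient: safetensors metadata only holds strings, so the string part of each cache's state can be saved right next to the arrays. The save_cache helper and the (arrays, meta-string) split of state are assumptions for the example.

import mlx.core as mx
from mlx.utils import tree_flatten

def save_cache(path, cache):
    # cache is a list of per-layer cache objects whose state is (arrays, meta_string)
    arrays = dict(tree_flatten([c.state[0] for c in cache]))    # mx.array leaves
    metadata = dict(tree_flatten([c.state[1] for c in cache]))  # string leaves
    mx.save_safetensors(path, arrays, metadata=metadata)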

awni avatar Oct 07 '24 18:10 awni

Huh, not sure how I managed to miss it in the RotatingKVCache...

I think this should be changed because the state is what we evaluate from the caches, which is why I was confused by the string. The smallest change that would IMHO be significantly better is to split it into state and serialization_state. It is a bit verbose, but at least it separates the two types of information cleanly.

Not much else changes. Line 47 would do something like

cache_data = [c.state for c in cache]
cache_info = [c.serialization_state for c in cache] # do we also want type(self).__name__ here?

and line 75 would change to

for c, state, serialization_state in zip(cache, arrays, info):
    c.state = state
    c.serialization_state = serialization_state

The rest remains the same... Wdyt?

angeloskath avatar Oct 07 '24 18:10 angeloskath

Yea, I thought about a separate property and/or overriding __getstate__ and __setstate__. The main downside I didn't like is that all the caches would need to implement it, but maybe the right call is to check if the attribute exists to avoid that. I think you're right that it could be cleaner, even if a little more verbose.

awni avatar Oct 07 '24 19:10 awni

I added a small base class that implements the empty meta state and makes the load/save code a tad cleaner. Should I push it on top, or are we avoiding base classes for some reason?
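
For concreteness, a minimal sketch of what such a base class might look like; the names _BaseCache and meta_state here are illustrative, not necessarily what ends up in the merged code.

class _BaseCache:
    # Default string-valued extra state: empty for caches that don't need it,
    # which keeps the save/load code uniform across cache types.
    @property
    def meta_state(self):
        return ""

    @meta_state.setter
    def meta_state(self, v):
        if v:
            raise ValueError("This cache has no meta_state to restore.")


class RotatingKVCache(_BaseCache):
    # Only the rotating cache overrides meta_state, so offset and _idx
    # survive a save/load round trip.
    def __init__(self, max_size, keep=0):
        self.max_size = max_size
        self.keep = keep
        self.offset = 0
        self._idx = 0

    @property
    def meta_state(self):
        return ",".join(map(str, (self.keep, self.max_size, self.offset, self._idx)))

    @meta_state.setter
    def meta_state(self, v):
        self.keep, self.max_size, self.offset, self._idx = map(int, v.split(","))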

Also, I just played with the chat command; it is an absolute joy to use :-)

angeloskath avatar Oct 07 '24 20:10 angeloskath

> I added a small base class that implements the empty meta state and makes the load/save code a tad cleaner. Should I push it on top, or are we avoiding base classes for some reason?

No reason at this point, please send the diff!

awni avatar Oct 07 '24 20:10 awni

Ok, I tested prompt caching with a few different models / cache types and it seems to work well. I'm going to merge this.

As a follow up we should consider:

  • Adding a way to serialize the chat context (see the sketch after this list)
  • Adding a chat endpoint to mlx_lm.server with prompt caching
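
For the first item, one possible starting point (assuming the save_prompt_cache / load_prompt_cache helpers from this diff; names and signatures may differ) is to persist the chat session's prompt cache and reload it later:

from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache, load_prompt_cache

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
prompt_cache = make_prompt_cache(model)

# ... run some chat turns that fill prompt_cache ...

# Persist the processed context so a later session can resume without
# recomputing the prompt.
save_prompt_cache("chat_context.safetensors", prompt_cache)

# Later, or in a new process:
prompt_cache = load_prompt_cache("chat_context.safetensors")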

awni avatar Oct 08 '24 03:10 awni

Awesome work, thanks for fixing it!

zcbenz avatar Oct 08 '24 10:10 zcbenz

So excited this got merged 😄

mark-lord avatar Oct 08 '24 15:10 mark-lord