Use 'Cache' Class in past_key_values for transformers
What behavior of the library made you think about the improvement?
See the bottom of https://github.com/huggingface/transformers/releases/tag/v4.36.0. A static cache implementation is also on its way: https://github.com/huggingface/transformers/pull/27931
I am currently working on an ExLlamaV2 integration, but the problem is that ExLlamaV2 already uses its own cache instance without passing it around. That cache is also static, which is not compatible with the dynamically shaped cache that is currently passed to Transformer.forward(). With this new Cache abstraction, it would be easier to implement both a static cache and the ExLlamaV2 integration.
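To illustrate the mismatch (a rough sketch, not ExLlamaV2's actual API): a dynamic cache grows along the sequence dimension on every step, while a static cache is preallocated to a maximum length and written in place.

```python
import torch

batch, heads, head_dim, max_len = 1, 8, 64, 2048

# Dynamic cache: key/value tensors grow by concatenation each step,
# so their sequence dimension changes shape on every forward pass.
dyn_keys = torch.zeros(batch, heads, 0, head_dim)
new_keys = torch.randn(batch, heads, 1, head_dim)
dyn_keys = torch.cat([dyn_keys, new_keys], dim=2)  # (1, 8, 1, 64), then (1, 8, 2, 64), ...

# Static cache: preallocated once to max_len; each step writes into a slot,
# so the tensor shape never changes (what ExLlamaV2-style caches expect).
static_keys = torch.zeros(batch, heads, max_len, head_dim)
position = 0
static_keys[:, :, position : position + 1, :] = new_keys
```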
How would you like it to behave?
It is not hard to replace the current 'past_key_values' with the new 'Cache' instance; it is a drop-in replacement with the newest version of transformers. The question (or design choice) is whether the model should hold the Cache instance, or a new Cache should be created for every call.
Could collide with #441.
Thank you for opening an issue! I have a few very naive questions and remarks:
- Does that mean that HF transformers is going to cache KV values by default?
- We were planning on implementing a better and more general KV-caching solution based on tries to minimize the memory impact for Beam Search; not sure that will be necessary anymore after this change in transformers?
- If you look at the Mamba integration you will see that you can return `None` for the KV cache and can implement the `__call__` and `forward` methods to catch the KV values without using them. That's probably what you'll need to do for the ExLlamaV2 integration (which would be awesome to have!)
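Something along these lines, as a rough sketch (the wrapper class and backend calls are hypothetical, not the actual Mamba integration):

```python
import torch


class ExLlamaV2Wrapper:
    """Hypothetical wrapper: the backend manages its own static cache internally."""

    def __init__(self, backend_model):
        self.backend_model = backend_model  # assumed to keep its own cache

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values=None,  # ignored: the backend caches internally
    ):
        logits = self.backend_model.forward(input_ids)  # assumed backend call
        return logits, None  # return None instead of a KV cache

    def __call__(self, input_ids, attention_mask, past_key_values=None):
        logits, _ = self.forward(input_ids, attention_mask, past_key_values)
        return logits[..., -1, :], None
```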
- No, HF transformers does not cache KV values by default. The new Cache class introduces a structure that can be extended for advanced uses, instead of just a tuple containing the KV values.
Rough code would be something like this:
```python
from typing import Optional, Tuple

import torch
from transformers import Cache, DynamicCache


class Transformer:
    def __call__(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
    ) -> Tuple[torch.FloatTensor, Cache]:
        # Create a fresh cache when none is passed in
        if past_key_values is None:
            past_key_values = DynamicCache()
        # The returned kv_cache is also a Cache instance, not a tuple
        logits, kv_cache = self.forward(input_ids, attention_mask, past_key_values)
        next_token_logits = logits[..., -1, :]
        return next_token_logits, kv_cache
```
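Usage would then look roughly like this (`model`, `input_ids`, `attention_mask`, and `max_new_tokens` are placeholders): the caller keeps the returned cache and passes it back on the next step, so the model itself stays stateless.

```python
import torch

# Hypothetical generation loop reusing the Cache object across steps.
model = Transformer()
kv_cache = None  # the first call creates a DynamicCache internally
tokens = input_ids
for _ in range(max_new_tokens):
    next_token_logits, kv_cache = model(tokens, attention_mask, past_key_values=kv_cache)
    next_token = next_token_logits.argmax(dim=-1, keepdim=True)
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
    tokens = next_token  # once a cache exists, only the new token needs to be passed
```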
- It looks like #441 is implementing a disk cache, but I think this is not optimal because it moves GPU tensors to the CPU every time the model generates. A GPU cache that resides inside the model instance would be faster, at the cost of a slight VRAM increase. We would have to compare wall-clock speed and VRAM usage to check this out; a rough measurement sketch follows below.
- Thanks, I will upload a draft PR soon.
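A rough sketch of such a comparison (`generate_fn` is a placeholder for one generation pass over either caching strategy): wall-clock time via `time.perf_counter`, peak VRAM via `torch.cuda.max_memory_allocated`.

```python
import time

import torch


def benchmark(generate_fn, n_runs: int = 5):
    """Measure mean wall-clock time and peak VRAM for a generation callable."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        generate_fn()
    torch.cuda.synchronize()
    mean_seconds = (time.perf_counter() - start) / n_runs
    peak_vram_mb = torch.cuda.max_memory_allocated() / 1024**2
    return mean_seconds, peak_vram_mb
```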
Hey there @kimjaewon96 @rlouf 👋
I'm one of the devs involved in the cache rework in transformers! In a nutshell, we are implementing support for arbitrary cache structures (as long as they support a few basic methods), which you can use as follows:
```python
cache = YourCacheClass(your_cache_kwargs)
gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=512, past_key_values=cache)
```
For now, we have a dynamically shaped cache class (the cache behavior that has been the default for a long time) and sink caches (= infinite-length generation). We are working on a fixed-size cache at the moment, which will boost the throughput of most generate-compatible models, and I'm sure the community will find new clever caching strategies!
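For example, with the two cache classes available today (a minimal sketch; the `model` and `input_ids` setup from the snippet above is assumed):

```python
from transformers import DynamicCache, SinkCache

# Dynamic cache: grows with the sequence, same behaviour as the long-time default.
cache = DynamicCache()
gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=512, past_key_values=cache)

# Sink cache: keeps a few "sink" tokens plus a sliding window, enabling very long generation.
sink_cache = SinkCache(window_length=1024, num_sink_tokens=4)
gen_out = model.generate(input_ids, do_sample=False, max_new_tokens=512, past_key_values=sink_cache)
```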
I'd love to hear some feedback, if you have any 🤗
This looks awesome and seems to be the way to go! We also have some ideas around cache management that are complementary to yours and it would be nice to try to make it work with your new abstraction.
One quick question @gante: we have our own generation layer in Outlines and don't use generate. Can I use a Cache instance for past_key_values and have it work out of the box, i.e. does the model instance take care of updating the cache, or do I need to update it manually?
@rlouf the model instance will take care of updating the cache instance :D
Please note the verb tense: future. At the moment, only Llama + static cache works this way. We are working on expanding this mechanism to all cache types/models, but it's hard for me to provide a time estimate. I expect Llama + all types of caches to be ready in the next version (v4.39), with other key models being enabled in the subsequent version (v4.40).
A manual update of the cache will also work, with the advantage that it works now :)
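A rough sketch of the manual route (single sequence, no attention-mask handling; `model`, `input_ids`, and `max_new_tokens` are placeholders, and it assumes a model whose forward already accepts a Cache instance, e.g. Llama on a recent transformers version):

```python
import torch
from transformers import DynamicCache

# Manual cache management: pass the Cache into forward and reuse what comes back.
cache = DynamicCache()
generated = input_ids
for _ in range(max_new_tokens):
    # Feed the full prompt on the first step, then only the newest token.
    step_input = generated if cache.get_seq_length() == 0 else generated[:, -1:]
    outputs = model(input_ids=step_input, past_key_values=cache, use_cache=True)
    cache = outputs.past_key_values  # the updated Cache instance
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    generated = torch.cat([generated, next_token], dim=-1)
```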