Quantized KV Cache
What does this PR do?
An implementation of a quantized KV cache using the quanto library. Introduces a new CacheConfig to store cache-related arguments and a new cache class, QuantoQuantizedCache. The implementation is partially based on the KIVI paper, but here we do per-token quantization for both keys and values.
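To make "per-token quantization" concrete, below is a minimal sketch in plain PyTorch of what it means for a cache tensor (illustration only, not the actual quanto kernels used by QuantoQuantizedCache): every token position gets its own scale and zero-point, computed over the head dimension.

import torch

def quantize_per_token(x: torch.Tensor, nbits: int = 4):
    # x: (batch, num_heads, seq_len, head_dim); one scale/zero-point per token
    qmax = 2**nbits - 1
    xmin = x.amin(dim=-1, keepdim=True)
    xmax = x.amax(dim=-1, keepdim=True)
    scale = (xmax - xmin).clamp(min=1e-8) / qmax
    zero_point = (-xmin / scale).round()
    q = (x / scale + zero_point).round().clamp(0, qmax).to(torch.uint8)
    return q, scale, zero_point

def dequantize_per_token(q, scale, zero_point):
    return (q.float() - zero_point) * scale

keys = torch.randn(1, 8, 16, 64)  # dummy key states
q, scale, zp = quantize_per_token(keys, nbits=4)
print((keys - dequantize_per_token(q, scale, zp)).abs().max())  # quantization error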
Example usage:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager").to("cuda:0")
inputs = tokenizer("Hello, how are you?", truncation=True, return_tensors="pt").to(model.device)
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized")
out_fp16 = model.generate(**inputs, do_sample=False, max_new_tokens=20)
print(f"text with quant cache: {tokenizer.batch_decode(out)}")
print(f"text with fp16 cache: {tokenizer.batch_decode(out_fp16)}")
Perplexity plots
Here the results differ from what we got earlier because previously I calculated perplexity in one forward pass, by quantizing and then dequantizing all keys and values. The new script uses the cache object and calculates perplexity per new token.
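For reference, the per-token loop is roughly of this shape (a simplified sketch, not the exact evaluation script behind the plots): the cache object is carried across steps and the negative log-likelihood of each next token is accumulated.

import torch

@torch.no_grad()
def perplexity_with_cache(model, input_ids, cache):
    # input_ids: (1, seq_len); `cache` can be e.g. a quantized cache instance
    nlls = []
    past_key_values = cache
    for i in range(input_ids.shape[1] - 1):
        out = model(input_ids[:, i : i + 1], past_key_values=past_key_values, use_cache=True)
        past_key_values = out.past_key_values
        logits = out.logits[:, -1, :]  # prediction for the next token
        target = input_ids[:, i + 1]
        nlls.append(torch.nn.functional.cross_entropy(logits, target))
    return torch.exp(torch.stack(nlls).mean())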
Eval on LongBench (scripts taken from the LongBench repo)
This is to compare with the KIVI method, since they ran the same evals on all datasets from LongBench.
| Dataset | KIVI fp16 | KIVI int2 | Our fp16 | Our int4 | Our int2 |
|---|---|---|---|---|---|
| TREC | 63.0 | 67.5 | 63.0 | 63.0 | 55.0 |
| SAMSum | 41.12 | 42.18 | 41.12 | 41.3 | 14.04 |
I could not find KIVI results on all of LongBench, so below is the transformers version only.
| Dataset | fp16 | int4 | int2 |
|---|---|---|---|
| TriviaQA | 84.28 | 84.76 | 63.64 |
| HotPotQA | 30.08 | 30.04 | 17.3 |
| Passage_retrieval_en | 8.5 | 9.5 | 4.82 |
Memory vs Latency plots
Same old plots showing memory consumption and latency for different cache types:
As we discussed, the quantized cache can start being integrated into the library, given the results we got so far. All the possible speed and pre-fill stage optimizations can be done later, as we get feedback from the community.
So, I would like to get a review on the PR :)
Thanks for the comments!
except for guarding quanto imports (also I would say safer to make local imports whenever possible - e.g. at QuantCache init)
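For reference, the guarded local import at cache init could look roughly like this (a sketch: the quanto symbols, error message, and argument handling are illustrative, not the final implementation):

from transformers import DynamicCache

class QuantoQuantizedCache(DynamicCache):
    def __init__(self, nbits: int = 4) -> None:
        super().__init__()
        # quanto is imported lazily, so transformers stays importable without it
        try:
            from quanto import qint2, qint4
        except ImportError:
            raise ImportError("`QuantoQuantizedCache` requires quanto: `pip install quanto`")
        self.nbits = nbits
        self.qtype = qint4 if nbits == 4 else qint2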
Okay, noted!
You raised a concern about switching between cache implementations - I made an attempt a while ago: https://github.com/huggingface/transformers/pull/29030 that got stale (😅); maybe that PR addresses your concern?
I love the generalized cache implementation idea. Not sure how this will work at the overall API level, given that Joao and Arthur are working on changes to the cache. I'll let Joao decide on that.
Maybe we could also track models that support quant cache with a private attribute _supports_quant_cache in xxxPreTrainedModel - what do you think?
Hmm, actually the quant cache should be supported anywhere DynamicCache is, which means everything except for old models like bart/t5. Yeah, I think we can add it for explicitness, until the cache API is refactored to be the same everywhere.
Thanks !
OK, that's great if that's the case; then I would say no need for that!
@gante added benchmark results to the PR description. Right now int4 has almost the same performance as fp16, sometimes a bit better. Also added some comparison with the KIVI paper.
(CI needs fixing -- possibly a simple make fix-copies)
Just curious, would this new cache work with torch.compile?
Nope, this one specifically no, as it inherits from DynamicCache, but another implementation based on the static cache could. Compile is not super happy with if/else and device placements, especially if they are input dependent (here it depends on the length of the processed input).
@ArthurZucker @ydshieh: "torch.compile with quanto is only supported for 8 bits quantization for now" (from @SunMarc, on a related conversation on slack)
I made the KV cache work with HQQ as a backend. It can simply be plugged in if a user writes their own "CacheClass". I am not planning to add it now as it needs more evaluation and experiments, but I wanted to show how anyone can add more backends. Do you think I should continue experimenting with HQQ, or can we simply put the code below as an example for users?
BTW, if we were to actually support more cache quant classes in the library, maybe we'll need to change the current QuantCache API a bit to be more versatile.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from hqq.core.quantize import Quantizer as HQQQuantizer
class HQQQuantizedCache(DynamicCache):
def __init__(
self,
nbits: int = 4,
axis: int = 0,
q_group_size: int = 64,
residual_length: int = 128,
compute_dtype: torch.dtype = torch.float16,
device: str = "cpu",
) -> None:
if nbits not in [2, 4, 8]:
raise ValueError(f"`nbits` has to be one of [`2`, `4`, `8`] but got {nbits}")
if axis not in [0, 1]:
raise ValueError(f"`axis` has to be one of [`0`, `1`] but got {axis}")
self._quantized_key_cache: List[Tuple[torch.Tensor, Dict]] = []
self._quantized_value_cache: List[Tuple[torch.Tensor, Dict]] = []
self.nbits = nbits
self.axis = axis
self.residual_length = residual_length
self.q_group_size = q_group_size
self.compute_dtype = compute_dtype
self.quantizer = HQQQuantizer
self.device = device
super().__init__()
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx:
q_key, meta_key = self._quantize(key_states.contiguous())
self._quantized_key_cache.append((q_key, meta_key))
q_value, meta_value = self._quantize(value_states.contiguous())
self._quantized_value_cache.append((q_value, meta_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
quant_key, meta_key = self._quantized_key_cache[layer_idx]
dequant_key = self.quantizer.dequantize(quant_key, meta_key)
quant_value, meta_value = self._quantized_value_cache[layer_idx]
dequant_value = self.quantizer.dequantize(quant_value, meta_value)
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
keys_to_return = torch.cat(keys_to_return, dim=-2)
values_to_return = torch.cat(values_to_return, dim=-2)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
q_key, meta_key = self._quantize(keys_to_return.contiguous())
self._quantized_key_cache[layer_idx] = (q_key, meta_key)
q_value, meta_value = self._quantize(values_to_return.contiguous())
self._quantized_value_cache[layer_idx] = (q_value, meta_value)
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return keys_to_return, values_to_return
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
# updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
def _quantize(self, tensor):
qtensor, meta = self.quantizer.quantize(
tensor,
axis=self.axis,
device=self.device,
compute_dtype=self.compute_dtype,
nbits=self.nbits,
group_size=self.q_group_size,
)
meta["compute_dtype"] = self.compute_dtype
return qtensor, meta
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
tokenizer.pad_token_id = tokenizer.eos_token_id
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16, attn_implementation="eager", device_map="auto")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
out = model.generate(
**inputs,
do_sample=False,
max_new_tokens=50,
past_key_values=HQQQuantizedCache(
nbits=2,
axis=1, # 2bit with axis=0 generates garbage
compute_dtype=torch.float16,
device=model.device
),
)
print(f"text with HQQ backend: {tokenizer.batch_decode(out)}")
I think that making the cache class versatile is great so people can build on top of it, without necessarily including anything in transformers! But this can come in a follow-up PR
@ArthurZucker yes, making a versatile cache class will go on another PR. In that case we can leave quanto as the only choice available, and the rest can be implemented by users themselves
sounds good
@ArthurZucker @gante I made a few changes from the last review:
- Now we support HQQ and quanto (quanto by default as it is a bit faster; we'll work on using optimized kernels later). For that we have a base QuantizedCache class, and each quantization method can make its own class from it by overriding the _quantize and _dequantize methods (see the sketch below).
- Added more kwargs to the config, so users can indicate the axis to quantize for keys and values separately and have more control over the process.
- Added _supports_quantized_cache, mainly because of Jamba. Jamba sets _supports_cache_class, but in the modeling code it checks for an attribute that is not available for all cache classes (here).
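To show the extension point, here is a rough sketch of a custom backend under this design (a toy symmetric int8 round-trip; it assumes QuantizedCache is importable from transformers, and the exact hook signatures may differ slightly from this sketch):

import torch
from transformers import QuantizedCache

class RoundTripQuantizedCache(QuantizedCache):
    """Toy backend: symmetric int8 round-trip; only the (de)quantization hooks are overridden."""

    def _quantize(self, tensor, axis):
        # one scale per slice along `axis` (per-token or per-channel, depending on the config)
        scale = tensor.abs().amax(dim=axis, keepdim=True).clamp(min=1e-8) / 127.0
        q = (tensor / scale).round().clamp(-127, 127).to(torch.int8)
        return q, scale

    def _dequantize(self, q_tensor):
        q, scale = q_tensor
        return q.to(scale.dtype) * scale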
I added a new usage example in the description and will rework the blog post a bit, given that we now support HQQ. This PR is ready for a second review!
Cool, merging 🤞🏻
Ran slow tests in quantization and generation locally, everything is passing.
I am wondering if we can have this work together with #30862. If so, we can probably get even more speedup!
@zucchini-nlp Could you share the simplest code snippet that you use for this PR to measure the runtime (latency)? I can try to incorporate this with #30862 🙏
@ydshieh This PR actually results in a slow-down because of quantization 😅 But we can probably check the memory usage. Here is a script I used, but you'd have to replace QuantCache with QuantoQuantizedCache because the evaluation was done on an older commit
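A minimal version of that measurement looks roughly like this (a sketch rather than the exact script; the model and generation length are just placeholders):

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

def measure(**generate_kwargs):
    # returns (latency in seconds, peak GPU memory in GiB) for one generate call
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    model.generate(**inputs, do_sample=False, max_new_tokens=500, **generate_kwargs)
    torch.cuda.synchronize()
    return time.perf_counter() - start, torch.cuda.max_memory_allocated() / 2**30

print("fp16 cache      (latency s, peak GiB):", measure())
print("quantized cache (latency s, peak GiB):", measure(cache_implementation="quantized"))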
OK. Thanks for sharing, so this PR is more about memory than speed.