Generation using cache gives weird sentences
While using the cache (past_key_values) during speculative decoding, or even during plain autoregressive decoding, the generated tokens can come out odd or nonsensical. Because of this behavior, speculative sampling is slowed down, sometimes to the point of being slower than autoregressive (AR) decoding.
speculative_generate edits the cache by pruning the last tokens whenever a rejection happens, so at first I thought the errors came from this.
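To make the pruning step concrete, here is a toy sketch of cropping a cache after a rejection. It is not the project's code: the cache is modeled as plain Python lists (one (keys, values) pair per layer, one entry per position) instead of tensors, and crop_cache, kept_length, make_cache and the draft tokens are hypothetical. It assumes the target model's verification pass fed every draft token, so all of them were cached.

```python
from typing import List, Tuple

# Toy legacy-style cache: one (keys, values) pair per layer, where keys[i] and
# values[i] stand for the entry of sequence position i. Real caches hold tensors
# of shape [batch, heads, seq_len, head_dim]; labels keep the bookkeeping visible.
LayerKV = Tuple[List[str], List[str]]
Cache = List[LayerKV]


def crop_cache(cache: Cache, max_length: int) -> Cache:
    """Hypothetical helper: keep only the first `max_length` positions of every layer."""
    return [(keys[:max_length], values[:max_length]) for keys, values in cache]


def kept_length(prefix_len: int, num_accepted: int) -> int:
    """Positions still valid after one verification step (toy assumption).

    The verification pass cached the prefix plus every draft token. Only the
    prefix and the accepted draft tokens stay valid; the token resampled at the
    first rejection has not been through the model yet, so it is not cached and
    must be fed as input on the next step.
    """
    return prefix_len + num_accepted


def make_cache(tokens: List[str], num_layers: int = 2) -> Cache:
    """Build a toy cache with one labeled entry per token and per layer."""
    return [([f"k:{t}" for t in tokens], [f"v:{t}" for t in tokens]) for _ in range(num_layers)]


prefix = ["i", "like", "this"]
draft = ["project", "very", "much"]   # hypothetical draft tokens, verified in one pass
cache = make_cache(prefix + draft)    # target cache now covers 6 positions

num_accepted = 1                      # "project" accepted, "very" rejected
cache = crop_cache(cache, kept_length(len(prefix), num_accepted))
print(len(cache[0][0]), cache[0][0])  # 4 ['k:i', 'k:like', 'k:this', 'k:project']
```

The off-by-one worth checking in a real implementation is whether the resampled token is already in the cache: under the assumption above it is not, so cropping one position further would drop a valid entry, and cropping one position less would keep the KV of the rejected draft token.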
But generation is also odd in autoregressive_generate, even though the cache there is neither edited nor pruned.
This leads me to wonder:
- Am I using past_key_values wrongly (passing it as a forward parameter and reading the updated KV cache from the output)?
- Or does the problem come from my transformers/torch versions (latest stable)?
- Or is this an issue in transformers itself?
Any help or advice would be greatly appreciated! Thanks.
This might be a relevant issue:
https://github.com/huggingface/transformers/issues/26344
- When the DynamicCache is initialized as empty, later versions of transformers appear to have a related cache issue: Hugging Face Transformers Issue #27985. You can refer to some of the solutions mentioned in that issue, or simply pass None to past_key_values during the first generation step only.
- When use_cache is set to True and the previous dynamic cache is not empty, the model should be given only the current input_id rather than the entire prefix.
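The two suggestions above combine into one loop pattern: no cache and the full prompt on the first call, then only the newest token on every later call. Below is a toy sketch of that pattern in pure Python. toy_forward is a hypothetical stand-in for calling the model with past_key_values and use_cache=True (its "cache" is just the list of processed token ids), so this illustrates the bookkeeping, not the transformers API.

```python
from typing import List, Optional, Tuple

VOCAB_SIZE = 50


def toy_forward(new_ids: List[int], past: Optional[List[int]]) -> Tuple[int, List[int]]:
    """Hypothetical stand-in for a forward pass with a KV cache.

    `past` plays the role of past_key_values: one entry per token already
    processed. The new tokens are appended to it, and the toy prediction
    depends on the whole context seen so far.
    """
    cache = [] if past is None else list(past)
    cache.extend(new_ids)
    next_id = sum(cache) % VOCAB_SIZE  # deterministic toy "argmax"
    return next_id, cache


def greedy_generate(prompt: List[int], max_new_tokens: int) -> Tuple[List[int], List[int]]:
    """Greedy decoding that reuses the cache."""
    past: Optional[List[int]] = None  # first step: no cache at all
    step_input = list(prompt)         # prefill with the whole prompt once
    output = list(prompt)
    for _ in range(max_new_tokens):
        next_id, past = toy_forward(step_input, past)
        output.append(next_id)
        step_input = [next_id]        # afterwards: only the token not yet cached
    return output, past


def greedy_generate_no_cache(prompt: List[int], max_new_tokens: int) -> List[int]:
    """Reference decoding that recomputes the full sequence at every step."""
    output = list(prompt)
    for _ in range(max_new_tokens):
        next_id, _ = toy_forward(output, None)
        output.append(next_id)
    return output


print(greedy_generate([3, 7, 11], 3))  # ([3, 7, 11, 21, 42, 34], [3, 7, 11, 21, 42])
```

Comparing the cached loop against greedy_generate_no_cache is a cheap sanity check: with greedy decoding both should produce the same tokens. Note that the last generated token is never in the cache, because it has not been fed to the model yet.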
Thank you both for your insights, @vdaita and @ShawnLue. I will run some tests based on your comments in the coming days.
After conducting several tests, I've encountered persistent issues with the past_key_values feature. While the problem is particularly noticeable with the Gemma 2 model, it raises concerns about the feature's overall reliability and compatibility across different models.
Attempted Solutions and Observations
- No cache on the first forward pass. Result: no past_key_values are returned.
- Passing a cache, even an empty one (tuple or DynamicCache). Result: tensor dimension errors occur from the first pass.
Upon reviewing the implementations of various models, it appears that the usage of past_key_values is inconsistent and potentially incorrect across different implementations.
If this feature does not work correctly for even one model, that raises questions about its overall reliability and whether it should be included in its current form.
- Should we wait for corrections to be made by model maintainers?
- Is it feasible to bypass the existing past_key_values implementation?
Given the open-source nature of this project and the likelihood that users may focus on specific models, we could leave the task of adapting the cache to individual users who may need it for specific models.
I have a small question about the cache. It is not exactly related to this issue, but I wanted to know why we pass the entire input when use_cache=True. Ideally we should prefill the model with the input and then send only the latest token to generate the next one, to reduce memory usage. But from my current understanding of the code, with the KV cache enabled the whole input sequence is passed in. Does HF transformers handle this internally, considering only the new tokens and ignoring the input tokens that are already in the cache?
Hello, your project is great, thank you very much! Regarding the abnormal results when using past_key_values, the likely cause is an error in how past_key_values is used. When past_key_values is enabled, the model computes attention over every token in input_ids together with the stored KV values, so tokens that are already cached get processed a second time, which causes errors.
For example, suppose we input "i like this": the model predicts "project" and stores the KV cache for "i", "like", and "this". If, at the next autoregressive step, we input "i like this project", the model computes attention for "i", "like", "this", and "project" one by one on top of the stored KV cache.
Normally, we only need to input "project", in other words only the tokens that are not yet in the KV cache.
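The "i like this project" example can be traced with a toy cache that simply records one entry per processed token. This is a pure-Python illustration, not transformers code; extend_cache is a hypothetical helper. It shows that re-feeding the full sequence duplicates the prefix in the cache, which also places "project" at position 6 instead of 3.

```python
from typing import List


def extend_cache(cache: List[str], new_tokens: List[str]) -> List[str]:
    """Toy KV cache: a forward pass stores one entry per input token it processes."""
    return cache + new_tokens


prefix = ["i", "like", "this"]
cache = extend_cache([], prefix)  # prefill: cache holds "i like this"
predicted = "project"

# Wrong: feed the whole sequence again while the cache is already filled.
wrong = extend_cache(cache, prefix + [predicted])

# Right: feed only the tokens that are not in the cache yet.
uncached = (prefix + [predicted])[len(cache):]  # ["project"]
right = extend_cache(cache, uncached)

print(wrong)  # ['i', 'like', 'this', 'i', 'like', 'this', 'project']
print(right)  # ['i', 'like', 'this', 'project']
```

Slicing by the current cache length, as in uncached, is one simple way to keep a manual decoding loop correct whether the caller passes the full sequence or just the newest token.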
In addition, there seems to be a problem with cache pruning in speculative sampling.
@Aiden-Frost While I was working on this, I found in the HF code that, when using the cache, the input was pruned to the last position regardless of whether you passed the full input or just the last token. But as @syhzcx mentioned, this seems to have changed (or, more likely, I misinterpreted it). Thank you so much for the insights; I will fix that. I'll also work on the cache pruning, but since the KV cache can differ from model to model, I'll do the best I can.