Sink Cache attention scores are strange. CausalMask seems not to be working.
System Info
- `transformers` version: 4.41.0
- Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.31
- Python version: 3.10.13
- Huggingface_hub version: 0.23.0
- Safetensors version: 0.4.2
- Accelerate version: 0.27.2
- Accelerate config:
  - compute_environment: LOCAL_MACHINE
  - distributed_type: MULTI_GPU
  - mixed_precision: no
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - gpu_ids: 0,1,2,3
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
  - dynamo_config: {'dynamo_backend': 'INDUCTOR', 'dynamo_mode': 'default', 'dynamo_use_dynamic': False, 'dynamo_use_fullgraph': False}
- PyTorch version (GPU?): 2.3.0+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?:
Who can help?
No response
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- [X] My own task or dataset (give details below)
Reproduction
dataset:
```python
from datasets import load_dataset

# Loaded from a local copy due to an unsolved UTF-8 error; it is exactly XSUM.
dataset = load_dataset("./xsum", split="test[:1000]")
```
I concatenate all 'document' texts to build a streaming task (to test StreamingLLM). The code is trivial but long, so it is omitted (a rough sketch of the idea is below).
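A minimal, hedged sketch of that omitted step, assuming a Hugging Face `tokenizer` loaded for llama2-7b-hf (variable names here are illustrative, not from the original script):

```python
# Sketch only: join all XSUM "document" fields into one long token sequence.
text = "\n\n".join(dataset["document"])
all_ids = tokenizer(text, return_tensors="pt").input_ids  # shape (1, total_len)
```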
model: LlamaForCausalLM; weights: llama2-7b-hf
core run code:
```python
def test_run(run_name, sink_len, win_len, stream,
             break_len=8000, get_attn_steps=[10]):
    kvcache = SinkCache(win_len, sink_len)
    perp_vs_len = []
    attn_scores_all = {}
    with torch.inference_mode():
        for i, input_ids in enumerate(stream):
            get_attn = (i in get_attn_steps)
            if i * stream.input_len > break_len:
                break
            output = model(
                input_ids.to(device),
                past_key_values=kvcache,
                use_cache=True,
                output_attentions=get_attn,
                return_dict=True,
            )  # type: ignore
            perp = perplexity(input_ids, output.logits).unsqueeze(0).item()
            perp_vs_len.append(((i + 1) * input_ids.shape[1], perp))
            if get_attn:
                attn_scores = output.attentions
                # Keep head 0 of a few representative layers.
                attn_scores = {
                    "layer_0_head_0": attn_scores[0][0, 0, :, :],
                    "layer_2_head_0": attn_scores[2][0, 0, :, :],
                    "layer_5_head_0": attn_scores[5][0, 0, :, :],
                    "layer_10_head_0": attn_scores[10][0, 0, :, :],
                    "layer_-1_head_0": attn_scores[-1][0, 0, :, :],
                }
                attn_scores_all[i] = attn_scores
    result = {
        'sink_len': sink_len,
        'win_len': win_len,
        'perp_vs_len': perp_vs_len,
        'attn_scores': attn_scores_all,
    }
    torch.save(result, f"./result/{run_name}_sink{sink_len}perp_attn.pt")
    return result
```
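The `perplexity` helper used above is not shown; one plausible sketch (next-token cross-entropy over the current chunk, an assumption rather than the author's implementation):

```python
import torch
import torch.nn.functional as F

def perplexity(input_ids, logits):
    # Hypothetical helper: logits at position t are scored against the token at t + 1.
    shift_logits = logits[:, :-1, :].float()
    shift_labels = input_ids[:, 1:].to(shift_logits.device)
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
    return torch.exp(loss)
```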
The stream object yields 100 tokens per iteration (a chunk of input_ids, like a list containing many tokens); a hypothetical stand-in is sketched below.
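For completeness, a stand-in `stream` matching that description (the `input_len` attribute is read by `test_run`; this class is not from the original script):

```python
class TokenStream:
    """Hypothetical stand-in: yields fixed-size chunks of a long token sequence."""

    def __init__(self, all_ids, input_len=100):
        self.all_ids = all_ids        # (1, total_len) token ids
        self.input_len = input_len    # tokens yielded per iteration, read by test_run

    def __iter__(self):
        for start in range(0, self.all_ids.shape[1], self.input_len):
            yield self.all_ids[:, start:start + self.input_len]

stream = TokenStream(all_ids, input_len=100)
```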
I plotted the attention scores; however, their upper-triangular part is not zero.
sink num = 0 (local window):
sink num = 16 (for StreamingLLM):
Expected behavior
The attention score matrix for prompt len = 0 (no KV cache) is correct:
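One way to quantify "the upper triangle is not zero" directly, instead of eyeballing the plots (a quick check, not part of the original script):

```python
# With a KV cache the score matrix is (query_len, key_len) with key_len >= query_len;
# everything strictly to the right of the shifted diagonal attends to "future" keys
# and should be zero under causal masking.
attn = output.attentions[0][0, 0]           # layer 0, head 0
q_len, k_len = attn.shape
future = torch.triu(attn, diagonal=k_len - q_len + 1)
print("max score on future positions:", future.abs().max().item())
```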
Supplementary:
It works with DynamicCache.
So something must be wrong with SinkCache or the related control code.
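A minimal side-by-side check along those lines (hedged sketch: it reuses the "future-score" check above and assumes `model`, `device`, and two consecutive 100-token chunks `chunk0`, `chunk1`; the window is chosen small enough that the second chunk forces the SinkCache to shift):

```python
from transformers import DynamicCache, SinkCache

def max_future_score(cache, chunk0, chunk1):
    # Prefill one chunk, then feed a second chunk and inspect its attention scores.
    with torch.inference_mode():
        model(chunk0.to(device), past_key_values=cache, use_cache=True)
        out = model(chunk1.to(device), past_key_values=cache, use_cache=True,
                    output_attentions=True)
    attn = out.attentions[0][0, 0]
    q_len, k_len = attn.shape
    return torch.triu(attn, diagonal=k_len - q_len + 1).abs().max().item()

print("DynamicCache:", max_future_score(DynamicCache(), chunk0, chunk1))
print("SinkCache:   ", max_future_score(SinkCache(window_length=128, num_sink_tokens=16),
                                         chunk0, chunk1))
```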
cc @gante @ArthurZucker
I have not worked on the sink cache, so I will let @gante answer here!
In cache_utils.py, I noticed that
```python
keys_to_keep = self.key_cache[layer_idx][
    :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
]
```
might go wrong when `-self.window_length + self.num_sink_tokens + key_states.shape[-2] >= 0`.
Not sure if it is relevant.
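A small demonstration of why that slice could misbehave if the start index flips to non-negative (illustrative numbers, not taken from a real run):

```python
import torch

window_length, num_sink_tokens = 16, 4
cache = torch.arange(16)  # toy stand-in for key_cache[layer_idx] along the sequence dim

# Normal decoding: one new token -> start index is -11, keeping 11 entries as intended.
new_len = 1
print(cache[-window_length + num_sink_tokens + new_len:].shape)  # torch.Size([11])

# Many new tokens at once: start index flips to +8, silently keeping the tail
# of the cache instead of keeping nothing.
new_len = 20
print(cache[-window_length + num_sink_tokens + new_len:].shape)  # torch.Size([8])
```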
It's been a bit since I worked on this, but I think that `-self.window_length + self.num_sink_tokens + key_states.shape[-2] >= 0` is not really possible.
- `window_length` is the max size of the cache, e.g. 1024.
- `num_sink_tokens` is some (usually small) positive integer, e.g. 4.
- `key_states.shape[-2]` is the size of the new additions to the cache.
In the code here: https://github.com/huggingface/transformers/blob/b72752f06830cb6cf8d21c284f68e15faa100c4d/src/transformers/cache_utils.py#L703-L706
We're in the "Shifting cache" phase, i.e. the cache already exists and we're now adding enough tokens to make it overflow. However, if it already exists, then I think (I'm not 100% sure on this) we always add 1 newly generated token, i.e. `key_states.shape[-2]` is 1. So I think a non-negative value can only happen if `num_sink_tokens >= window_length - 1`, which is not normal behaviour.
However, if it is somehow possible to add a bunch of tokens in one go when the cache already exists, then I think it would be possible to mess this up. In that case keys_to_keep should really be empty (we're skipping way ahead and keeping no tokens), but the overflow of `-self.window_length + self.num_sink_tokens + key_states.shape[-2] >= 0` into the positives allows some keys to stay. The new tokens then get appended and we accidentally get a cache that's too large here: https://github.com/huggingface/transformers/blob/b72752f06830cb6cf8d21c284f68e15faa100c4d/src/transformers/cache_utils.py#L724
But I think that should cause a pretty easy-to-spot crash, as the cache would now be bigger than the window size, which should not be possible.
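To illustrate that failure mode with plain 1-D tensors (a toy stand-in for the real cache update, not the transformers code itself):

```python
import torch

window_length, num_sink_tokens = 16, 4
sink_keys = torch.zeros(num_sink_tokens)   # stand-in for the sink portion of the cache
existing = torch.zeros(12)                 # cache already at capacity (4 + 12 = 16)
key_states = torch.zeros(20)               # a large batch of tokens added in one go

start = -window_length + num_sink_tokens + key_states.shape[-1]   # +8 instead of negative
keys_to_keep = existing[start:]                                   # 4 keys survive; should be 0
new_cache = torch.cat([sink_keys, keys_to_keep, key_states])
print(new_cache.shape[0], ">", window_length)                     # 28 > 16: oversized cache
```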
- Tom Aarsen
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.