SnapKV

The first generation token output sees the whole cache key and value

Open PengWenChen opened this issue 1 year ago • 3 comments

https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L130

Hi there~ Thanks for your great work! The `past_key_value` at L130 is indeed updated with the new compressed key and value states. However, the first generated token (L168) is still produced from the full, uncompressed key and value cache, even after the prompt has been compressed. https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L168-L176 Is this a bug?

PengWenChen — Jan 06 '25

Because this happens during the prefill stage, it is unrelated to KV compression, so the full KV is used for the computation.
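To make the two stages concrete, here is a minimal toy sketch (not the SnapKV implementation; the selection rule is a simplified stand-in for SnapKV's observation-window voting): during prefill, attention runs over the full prompt KV, which is what produces the logits for the first generated token; only afterwards is a compressed subset of the KV kept for the decode stage.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def prefill_then_compress(q, k, v, keep):
    """Prefill attends over the FULL prompt KV (this yields the first
    generated token's hidden state). Then only `keep` positions, chosen
    here by the last query's attention weights, are stored for decoding."""
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))   # (L, L) full attention
    out = attn @ v                                   # prefill output: sees all KV
    scores = attn[-1]                                # last token's attention weights
    top = np.sort(np.argsort(scores)[-keep:])        # toy "important position" pick
    return out, k[top], v[top]                       # compressed cache for decode

rng = np.random.default_rng(0)
L, d, keep = 8, 4, 3
q, k, v = (rng.standard_normal((L, d)) for _ in range(3))
out, k_c, v_c = prefill_then_compress(q, k, v, keep)
print(out.shape, k_c.shape)  # (8, 4) (3, 4): full-KV output, compressed cache
```

From the second generated token onward, decode-stage attention only sees the `keep` cached positions, which is where the memory saving comes from.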

1028xxL — Jan 07 '25

To my understanding, the first generated token comes from the last output logit of the prefill stage. So the first token of the model's response comes from the `attn_output` here, right? https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L168-L176

If so, then the first generated (predicted) token sees the whole KV from the input prompt. If not, what is the input token that produces the first generated token after KV compression? There must be an input token whose hidden states predict the first response token, right?

PengWenChen — Jan 07 '25

Hello, has there been a resolution or any further discussion on this?

I think the simple fix is to do this here: `key_states, value_states = past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)`
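Note that reassigning `key_states`/`value_states` from the compressed update would change the first token's logits, since the last query would then attend only to the kept positions instead of the full prompt. A toy numpy sketch of that difference (hypothetical selection of the last `keep` positions, just for illustration):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def last_token_out(q_last, k, v):
    """Attention output for the final prompt position given some KV cache."""
    attn = softmax(q_last @ k.T / np.sqrt(k.shape[-1]))
    return attn @ v

rng = np.random.default_rng(1)
L, d, keep = 8, 4, 3
k = rng.standard_normal((L, d))
v = rng.standard_normal((L, d))
q_last = rng.standard_normal(d)

full = last_token_out(q_last, k, v)            # current behavior: full prompt KV
idx = np.arange(L - keep, L)                   # toy choice of kept positions
comp = last_token_out(q_last, k[idx], v[idx])  # behavior under the proposed fix
print(np.allclose(full, comp))                 # generally False: outputs differ
```

So the question is really whether the first response token *should* be conditioned on the full prompt (the current behavior) or on the compressed cache (what the proposed one-line change would do).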

Thanks!

akhauriyash — Jan 27 '25