The first generated token sees the whole cached key and value
https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L130
Hi there~ Thanks for your great work! The past_key_value at L130 is indeed updated with the compressed keys and values. However, the first generated token (L168) is still produced from the full, uncompressed keys and values after prompt compression. https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L168-L176 Is this a bug?
Because this happens during the prefill stage, it is unrelated to KV compression, so the full KV is used for the computation.
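To make the ordering under discussion concrete, here is a minimal, self-contained toy sketch (all names hypothetical, not SnapKV's actual code): during prefill, a compressed copy of the KV is stored in the cache for later decoding steps, but the attention that produces the logit for the first generated token still runs over the full KV.

```python
def compress(keys, values, keep=2):
    """Stand-in for SnapKV's observation-window selection: keep the last `keep` entries."""
    return keys[-keep:], values[-keep:]

class ToyCache:
    """Minimal stand-in for past_key_value."""
    def __init__(self):
        self.keys, self.values = [], []
    def update(self, k, v):
        self.keys += k
        self.values += v
        return self.keys, self.values

def prefill(full_keys, full_values, cache):
    # 1. Compress and store in the cache; decoding steps will see only this.
    k_c, v_c = compress(full_keys, full_values)
    cache.update(k_c, v_c)
    # 2. Attention for the prompt (and hence the first generated token)
    #    still uses the FULL keys/values, not the compressed ones.
    return len(full_keys)  # attention span seen by the first generated token

cache = ToyCache()
span = prefill([1, 2, 3, 4], [1, 2, 3, 4], cache)
print(span)             # 4: first token attends over the full prompt
print(len(cache.keys))  # 2: the cache holds only the compressed KV
```

So the asymmetry is deliberate in this reading: only tokens generated after prefill are restricted to the compressed cache.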
To my understanding, the first generated token comes from the last output logit of the prefill stage.
So the first token of the model response comes from the attn_output here, right?
https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L168-L176
If so, then the first generated (predicted) token sees the whole KV of the input prompt. If not, what is the input token used to generate the first token after KV compression? There must be an input token that becomes the hidden states and predicts the first response token, right?
Hello, has there been any resolution or further discussion on this?
I think the simple fix is to do this here:
key_states, value_states = past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
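To illustrate what that one-line change would do, here is a self-contained toy sketch (hypothetical names, not SnapKV's actual code): by assigning the tensors returned by the cache's update back to key_states/value_states, the prefill attention, and therefore the first generated token, would also run on the compressed KV.

```python
class ToyCache:
    """Minimal stand-in for past_key_value: update() returns the stored KV."""
    def __init__(self):
        self.keys, self.values = [], []
    def update(self, k, v):
        self.keys += k
        self.values += v
        return self.keys, self.values

def prefill_with_fix(key_states, value_states, cache, keep=2):
    # Stand-in compression: keep only the last `keep` entries.
    k_c, v_c = key_states[-keep:], value_states[-keep:]
    # Proposed fix: reassign from update(), so attention for the first
    # generated token sees the compressed KV instead of the full KV.
    key_states, value_states = cache.update(k_c, v_c)
    return len(key_states)  # attention span seen by the first generated token

cache = ToyCache()
print(prefill_with_fix([1, 2, 3, 4], [1, 2, 3, 4], cache))  # 2: compressed span
```

Note this changes model behavior: the first response token would no longer condition on the full prompt KV, which may or may not be the intended semantics.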
Thanks!