What happens when the total KV length exceeds the max capacity during response generation?
Hi, thanks for your great work!
It's impressive to compress the long prompt's KV cache into a constant length. I'm wondering whether this approach also considers the case where the generated response pushes the cache beyond the maximum capacity.
As far as I can tell, the code takes the branch at line 127 only during the prefilling stage, and always takes the branch at line 131 during the generation stage. Is my understanding correct? https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/mistral_hijack_4_37.py#L127-L133
Thanks for the question. Our method mainly focuses on long-context scenarios where the input is usually much longer than the output, which is where it benefits generation speed. We didn't consider compression during the generation stage. I believe other work, such as H2O, also compresses during generation.
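To make the consequence of this answer concrete, here is a minimal sketch of how the cache size would evolve if only the prompt is compressed and every generated token is appended afterwards. It is a toy model, not the repository's code: the function name `cache_lengths` and the parameters `prompt_len`, `max_capacity`, and `new_tokens` are hypothetical, and it assumes a prompt shorter than the budget is kept uncompressed.

```python
def cache_lengths(prompt_len, max_capacity, new_tokens):
    """Toy model: KV cache size after prefill and after each decode step.

    Assumption: compression happens once, at prefill, and never again.
    """
    # Prefill: the prompt is reduced to at most `max_capacity` entries.
    cache_len = min(prompt_len, max_capacity)
    lengths = [cache_len]
    for _ in range(new_tokens):
        # Decode: one new KV entry is appended and nothing is evicted.
        cache_len += 1
        lengths.append(cache_len)
    return lengths


lengths = cache_lengths(prompt_len=8000, max_capacity=1024, new_tokens=5)
print(lengths)  # [1024, 1025, 1026, 1027, 1028, 1029]
```

Under these assumptions the cache grows by one entry per generated token, so it can exceed the prompt budget during a long generation.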
Hi! Could you explain what kv_seq_len means in the llama_hijack code? I'm having trouble understanding it😂. Also, why is the condition at L84 in llama_hijack `==`, while the corresponding condition in mistral_hijack is `>=`? Thank you very much for your time!
https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L127-L133
https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/llama_hijack_4_37.py#L65-L71
https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/llama_hijack_4_37.py#L84-L90
Am I understanding it correctly that L84 in llama_hijack aims to determine whether the model is at prefilling stage or decoding stage?
SnapKV/snapkv/monkeypatch/mistral_hijack_4_37.py
Lines 127 to 133 in 82135ce

    if key_states.shape[-2] >= kv_seq_len: # [SnapKV] add kv_cluster
        self.kv_seq_len = kv_seq_len
        key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
        past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
    else:
        self.kv_seq_len += q_len
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

SnapKV/snapkv/monkeypatch/llama_hijack_4_37.py
Lines 65 to 71 in 82135ce

    if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
        if self.kv_seq_len != 0:
            kv_seq_len += self.kv_seq_len
        else:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
    else:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

SnapKV/snapkv/monkeypatch/llama_hijack_4_37.py
Lines 84 to 90 in 82135ce

    if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
        self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
        key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
        past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
    else:
        self.kv_seq_len += q_len
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
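
One possible reading of the quoted snippets, sketched as a toy simulation rather than the actual attention code: `kv_seq_len` appears to track the logical number of tokens seen so far, which can be larger than the number of entries physically stored once the prompt has been compressed. The class `ToyLayer`, its `cache_len` field, and the `max_capacity` parameter are hypothetical stand-ins introduced only for illustration; tensors are replaced by plain integers.

```python
class ToyLayer:
    """Integer-only stand-in for one attention layer (hypothetical)."""

    def __init__(self, max_capacity):
        self.max_capacity = max_capacity  # hypothetical prompt budget
        self.kv_seq_len = 0               # logical tokens seen so far
        self.cache_len = 0                # entries physically stored

    def forward(self, q_len):
        # Logical length = new tokens + tokens already seen.
        kv_seq_len = q_len
        if self.kv_seq_len != 0:
            kv_seq_len += self.kv_seq_len
        else:
            kv_seq_len += self.cache_len
        if q_len == kv_seq_len:
            # Nothing was cached before this call, so this is prefill:
            # record the true length, store a compressed prompt.
            self.kv_seq_len = kv_seq_len
            self.cache_len = min(q_len, self.max_capacity)
            return "prefill"
        # Otherwise we are decoding: append without compressing.
        self.kv_seq_len += q_len
        self.cache_len += q_len
        return "decode"


layer = ToyLayer(max_capacity=4)
steps = [layer.forward(10), layer.forward(1), layer.forward(1)]
print(steps, layer.kv_seq_len, layer.cache_len)
# ['prefill', 'decode', 'decode'] 12 6
```

In this sketch the logical length is never smaller than `q_len`, so a `>=` comparison could only succeed when the two values are equal. Under that assumption, `==` and `>=` select the same branch, and both act as a prefill-versus-decode check.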