
What happens when the total KV length exceeds the max-capacity length during response generation?

Open · PengWenChen opened this issue 1 year ago • 3 comments

Hi, thanks for your great work!

It's impressive to compress the long prompt KVs into a constant length. I'm wondering whether the scenario here also considers the case where the generated response pushes the KV length beyond the maximum capacity?

It always goes to L127 only during the prefilling stage, and during the generation stage it always goes to L131. Is my understanding correct? https://github.com/FasterDecoding/SnapKV/blob/main/snapkv/monkeypatch/mistral_hijack_4_37.py#L127-L133
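For reference, here is my reading of the two branches as a toy trace. The numbers (max_capacity, prompt_len) are made up and the shapes follow the usual Hugging Face (batch, heads, seq, head_dim) layout; this is an illustration, not the SnapKV code itself:

```python
import torch

max_capacity = 1024           # hypothetical compression budget
prompt_len = 8000             # long prompt

# Prefill: the whole prompt is processed in one forward pass.
q_len = prompt_len
kv_seq_len = prompt_len                       # nothing cached yet
key_states = torch.empty(1, 8, q_len, 128)    # keys computed in this pass
assert key_states.shape[-2] >= kv_seq_len     # -> L127: prompt KV compressed to ~max_capacity

# Generation: one new token per step.
cached = max_capacity                         # compressed prompt KV already in the cache
for _ in range(200):
    q_len = 1
    kv_seq_len = cached + q_len
    key_states = torch.empty(1, 8, q_len, 128)
    assert key_states.shape[-2] < kv_seq_len  # -> L131: just append, no further compression
    cached += 1

# After a long response the cache holds max_capacity + generated tokens,
# i.e. it keeps growing past the budget during generation.
```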

PengWenChen avatar Oct 23 '24 02:10 PengWenChen

Thanks for the question. Our method mainly focuses on long-context scenarios where the input is usually much longer than the output, which is where the generation-speed benefit comes from. We didn't consider compressing during the generation stage. I believe other work, such as H2O, does compress during generation.
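To make the contrast concrete, here is a toy, list-based sketch of the two policies. It is not the real SnapKV or H2O code; the (token, score) pairs just stand in for KV entries with some importance signal such as accumulated attention:

```python
def snapkv_style(prompt, generated, budget):
    # Compress the prompt once at prefill, then only append during generation.
    cache = sorted(prompt, key=lambda t: t[1], reverse=True)[:budget]
    cache += generated                       # grows past the budget for long outputs
    return cache

def h2o_style(prompt, generated, budget):
    # Keep evicting the lowest-scored entry so the cache never exceeds the budget.
    cache = list(prompt)
    for item in generated:
        cache.append(item)
        if len(cache) > budget:
            cache.remove(min(cache, key=lambda t: t[1]))
    return cache

prompt    = [(f"p{i}", i % 7) for i in range(5000)]   # long prompt
generated = [(f"g{i}", 3)     for i in range(300)]    # long response
print(len(snapkv_style(prompt, generated, budget=1024)))  # 1324: budget exceeded
print(len(h2o_style(prompt, generated, budget=1024)))     # 1024: stays at the budget
```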

WendyH1108 avatar Oct 26 '24 04:10 WendyH1108


Hi! Can you explain the meaning of kv_seq_len in the llama_hijack code? I'm having trouble understanding it 😂. And why is the condition in L84 `==` while the corresponding condition in mistral_hijack is `>=`? Thank you greatly for your time!

https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/mistral_hijack_4_37.py#L127-L133

https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/llama_hijack_4_37.py#L65-L71

https://github.com/FasterDecoding/SnapKV/blob/82135ce2cc60f212a9ba918467f3d9c8134e163f/snapkv/monkeypatch/llama_hijack_4_37.py#L84-L90

SHA-4096 avatar Mar 03 '25 17:03 SHA-4096

Am I understanding it correctly that L84 in llama_hijack aims to determine whether the model is at the prefilling stage or the decoding stage?


SnapKV/snapkv/monkeypatch/mistral_hijack_4_37.py

Lines 127 to 133 in 82135ce

if key_states.shape[-2] >= kv_seq_len: # [SnapKV] add kv_cluster
    self.kv_seq_len = kv_seq_len
    key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
    past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
else:
    self.kv_seq_len += q_len
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

SnapKV/snapkv/monkeypatch/llama_hijack_4_37.py

Lines 65 to 71 in 82135ce

if hasattr(self, "kv_seq_len"): #[SnapKV] add kv_seq_len
    if self.kv_seq_len != 0:
        kv_seq_len += self.kv_seq_len
    else:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
else:
    kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

SnapKV/snapkv/monkeypatch/llama_hijack_4_37.py

Lines 84 to 90 in 82135ce

if key_states.shape[-2] == kv_seq_len: # [SnapKV] add kv_cluster
    self.kv_seq_len = kv_seq_len # [SnapKV] register kv_seq_len
    key_states_compress, value_states_compress = self.kv_cluster.update_kv(key_states, query_states, value_states, attention_mask, self.num_key_value_groups)
    past_key_value.update(key_states_compress, value_states_compress, self.layer_idx, cache_kwargs)
else:
    self.kv_seq_len += q_len
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
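Tracing the shapes myself, it looks like kv_seq_len is the original (pre-compression) number of tokens the layer has seen, and the shape check is what separates prefill from decoding. A toy trace under the usual Hugging Face (batch, heads, seq, head_dim) layout (illustration only, not the SnapKV code):

```python
import torch

bsz, heads, head_dim = 1, 8, 128
prompt_len = 4096

# Prefill: nothing cached yet, so get_usable_length() would be 0 and q_len == kv_seq_len.
q_len = prompt_len
kv_seq_len = q_len + 0
key_states = torch.empty(bsz, heads, q_len, head_dim)
print(key_states.shape[-2] == kv_seq_len)     # True  -> compression branch (llama L84)
registered_kv_seq_len = kv_seq_len            # what self.kv_seq_len remembers

# Decoding: only the new token's keys are computed in this pass.
q_len = 1
kv_seq_len = registered_kv_seq_len + q_len    # uses the original length, not the compressed cache size
key_states = torch.empty(bsz, heads, q_len, head_dim)
print(key_states.shape[-2] == kv_seq_len)     # False -> append branch
print(key_states.shape[-2] >= kv_seq_len)     # also False, so mistral's `>=` takes the same path here
```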

SHA-4096 avatar Mar 03 '25 17:03 SHA-4096