kv_seq_len bug?
```python
if kv_seq_len > local_branch + global_branch and use_lambda_mask:
    past_key_value = (
        torch.cat([
            key_states[..., :global_branch, :],
            key_states[..., -local_branch:, :],
        ], dim=-2),
        torch.cat([
            value_states[..., :global_branch, :],
            value_states[..., -local_branch:, :],
        ], dim=-2),
        key_position_ids[..., :local_branch + global_branch],
    ) if use_cache else None
```
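To see which positions the truncation keeps, here is a toy sketch using a plain Python list as a stand-in for the sequence dimension of the tensors; the `global_branch`/`local_branch` values are illustrative, not the repository's defaults:

```python
global_branch, local_branch = 2, 3  # illustrative values

# Stand-in for the sequence dimension of key_states / value_states.
key_positions = list(range(8))  # 8 cached positions

# Same slicing as the snippet: keep the first `global_branch` positions
# (the global prefix) and the last `local_branch` (most recent) positions.
kept = key_positions[:global_branch] + key_positions[-local_branch:]
print(kept)  # [0, 1, 5, 6, 7]
```

The middle positions (2, 3, 4) are dropped from the cache, which is why its length can differ from the `kv_seq_len` used in the current forward pass.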
Why does the code in models/llama.py, lines 144-155, update past_key_value but not kv_seq_len?
Hi chenlidar. Thanks for your interest!
The code you are looking at seems to come from an older version of our codebase. Note that in this section, kv_seq_len is only associated with key_states and value_states, which are actively used inside the function; it has nothing to do with (and is not responsible for tracking) past_key_value.
past_key_value is only used as an information pool by later function calls, and we do not actually care about its precise length. kv_seq_len is only updated when the cached tensors are concatenated back onto key_states and value_states, as in lines 133-140 of the same file.
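The lifecycle described above can be sketched in plain Python (no torch): kv_seq_len is derived at the concatenation point from the cache plus the new tokens, and the cache is truncated afterwards. The `step` helper and the branch sizes here are illustrative assumptions, not the repository's actual code:

```python
def step(cache, new_tokens, global_branch, local_branch):
    # kv_seq_len covers exactly what attention sees this step:
    # the cached positions plus the newly projected ones.
    seq = cache + new_tokens
    kv_seq_len = len(seq)
    # Truncate the cache afterwards: keep the global prefix
    # and the local (most recent) suffix.
    if kv_seq_len > global_branch + local_branch:
        cache = seq[:global_branch] + seq[-local_branch:]
    else:
        cache = seq
    return kv_seq_len, cache

cache = []
kv_seq_len, cache = step(cache, list(range(6)), 2, 3)
# kv_seq_len == 6, cache keeps [0, 1] + [3, 4, 5]
```

Note that kv_seq_len reflects the full concatenated length at each step, while the stored cache length stays capped at global_branch + local_branch; this is why the snippet never needs to update kv_seq_len when it truncates past_key_value.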
P.S. These lines have been removed in our later versions to accommodate newer Transformers versions.
We hope this resolves your question. Please let us know if you have further questions!