Fix bug in masking when kv cache is used.
Thank you for creating this project; I learned a lot from it!
There seems to be a small bug in the masking when the KV cache is enabled:
- W/o kv cache, `mask_bool = self.mask.bool()[:num_tokens, :num_tokens]` yields the intended results.
- W/ kv cache, `num_tokens` would be set to 1, and `mask_bool` would be a tensor of shape (1, 1). However, we want `mask_bool` to be a tensor of shape (1, num_tokens_K) (see the sketch below).
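For reference, a minimal sketch of the shape mismatch during a cached decoding step (the mask construction and sizes here are illustrative assumptions, not the actual model code):

```python
import torch

# Hypothetical shapes for a single cached decoding step:
# the new query covers 1 token, while keys span 5 positions in total.
context_length = 8
num_tokens = 1        # queries: only the newly generated token
num_tokens_K = 5      # keys: all cached tokens plus the new one

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

# Original slicing: both dimensions use the query length.
mask_bool = mask.bool()[:num_tokens, :num_tokens]
print(mask_bool.shape)    # torch.Size([1, 1])

# Attention scores for the cached step have shape (num_tokens, num_tokens_K).
attn_scores = torch.randn(num_tokens, num_tokens_K)
print(attn_scores.shape)  # torch.Size([1, 5])
```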
The following changes address this bug.
Thanks for updating the masking. I just added some tests to make the equivalence check easier... it looks like the updated masking now creates a mismatch between the base model and the cached model.
I may have to rethink this when my brain is a bit fresher tomorrow morning, but I think the original code is correct because we don't recompute the older tokens, just the new token. So in that case the mask is technically not needed as there are no future tokens.
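A minimal sketch of why the mask is effectively a no-op for the single new query token (illustrative shapes only, not the actual model code):

```python
import torch

# Hypothetical cached step: 4 past tokens in the cache plus the new one.
num_tokens_K = 5
attn_scores = torch.randn(1, num_tokens_K)

# The "correct" causal mask row for the last position: every key is at or
# before the query position, so nothing gets masked out.
correct_mask_row = torch.zeros(1, num_tokens_K, dtype=torch.bool)

masked = attn_scores.masked_fill(correct_mask_row, -torch.inf)
print(torch.equal(masked, attn_scores))  # True: the mask changes nothing
```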
I think you are right! With the original implementation, during KV cache:

- `mask_bool = self.mask.bool()[:num_tokens, :num_tokens]` does give you a mask with shape (1, 1). During `attn_scores.masked_fill_(mask_bool, -torch.inf)`, `mask_bool` gets automatically broadcast to the shape of `attn_scores`, which is (1, num_tokens_K).
- Alternatively, we could use `mask_bool = self.mask.bool()[self.curr_pos : self.curr_pos + num_tokens, :num_tokens_K]` to achieve the same thing without relying on broadcasting. Note that `self.ptr_current_pos` needs to be added, like in `gpt_with_kv_cache_optimized.py` (see the sketch after this list).
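A minimal sketch comparing the two variants (illustrative values; `ptr_current_pos` here is just a stand-in counter for how many tokens are already cached, not the repo's actual attribute):

```python
import torch

context_length = 8
ptr_current_pos = 4   # tokens already in the KV cache (assumed value)
num_tokens = 1        # the new query token
num_tokens_K = ptr_current_pos + num_tokens

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
attn_scores = torch.randn(num_tokens, num_tokens_K)

# Variant 1: original slicing, shape (1, 1); masked_fill_ broadcasts it.
mask_a = mask.bool()[:num_tokens, :num_tokens]
scores_a = attn_scores.clone()
scores_a.masked_fill_(mask_a, -torch.inf)

# Variant 2: explicit slicing using the cache position, shape (1, num_tokens_K).
mask_b = mask.bool()[ptr_current_pos : ptr_current_pos + num_tokens, :num_tokens_K]
scores_b = attn_scores.clone()
scores_b.masked_fill_(mask_b, -torch.inf)

print(torch.equal(scores_a, scores_b))  # True: both leave the scores unmasked
```

Both variants produce all-`False` mask entries for the cached step, so the scores are left unchanged either way; the explicit slice just makes the intended shape visible.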
Thanks for the suggestion. I think making the code more explicit here is a good idea.