
Fix bug in masking when kv cache is used.

martinzwm opened this issue 6 months ago · 3 comments

Thank you for creating this project; I learned a lot from it!

There seems to be a small bug in the masking when the KV cache is enabled:

  • Without the KV cache, mask_bool = self.mask.bool()[:num_tokens, :num_tokens] yields the intended result.
  • With the KV cache, num_tokens is set to 1 (only the new token is processed), so mask_bool becomes a tensor of shape (1, 1). However, we want mask_bool to be a tensor of shape (1, num_tokens_K), where num_tokens_K is the number of keys seen so far (see the sketch after this list).
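
For illustration, here is a minimal standalone sketch of the shape mismatch being described. The tensor sizes and the mask buffer are hypothetical, not the repo's exact code:

```python
# Standalone sketch of the reported shape mismatch (hypothetical sizes).
import torch

context_length = 8
# Causal mask buffer, analogous to self.mask in the attention module.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

num_tokens = 1    # with the KV cache, only the new token is processed
num_tokens_K = 5  # cached keys plus the new key (hypothetical count)

mask_bool = mask.bool()[:num_tokens, :num_tokens]
print(mask_bool.shape)    # torch.Size([1, 1])

# ...but the scores for the new query cover all keys seen so far:
attn_scores = torch.randn(num_tokens, num_tokens_K)
print(attn_scores.shape)  # torch.Size([1, 5])
```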

The following changes address this bug.

martinzwm commented Jun 22 '25 21:06

Thanks for updating the masking. I just added some tests to make the equivalence check easier... it looks like the updated masking now creates a mismatch between the base model and the cached model.

rasbt commented Jun 22 '25 22:06

I may have to rethink this when my brain is a bit fresher tomorrow morning, but I think the original code is correct because we don't recompute the older tokens, just the new token. So in that case the mask is technically not needed as there are no future tokens.
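
A quick numerical check of that reasoning (a standalone sketch with arbitrary sizes, not the repo's code): the newest token's mask row contains no future positions, so masking it is a no-op:

```python
# For the last token, the causal-masked full pass and an unmasked
# incremental step give the same output, since all keys are past tokens.
import torch

torch.manual_seed(0)
T, d = 5, 4
q = torch.randn(T, d)
k = torch.randn(T, d)
v = torch.randn(T, d)

# Full forward pass with a causal mask.
scores = q @ k.T / d**0.5
causal = torch.triu(torch.ones(T, T, dtype=torch.bool), diagonal=1)
full_out = torch.softmax(scores.masked_fill(causal, -torch.inf), dim=-1) @ v

# Incremental step for the last token, no mask: it may attend to every key.
step_scores = q[-1:] @ k.T / d**0.5
step_out = torch.softmax(step_scores, dim=-1) @ v

print(torch.allclose(full_out[-1:], step_out))  # True
```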

rasbt commented Jun 22 '25 23:06

I think you are right! With the original implementation, during KV-cache decoding:

  • mask_bool = self.mask.bool()[:num_tokens, :num_tokens] does give a mask of shape (1, 1). During attn_scores.masked_fill_(mask_bool, -torch.inf), mask_bool gets automatically broadcast to the shape of attn_scores, which is (1, num_tokens_K). Since that single entry is False (the new token may attend to itself), the broadcast mask hides nothing, which is correct because all cached keys are past tokens.
  • Alternatively, we could use mask_bool = self.mask.bool()[self.ptr_current_pos : self.ptr_current_pos + num_tokens, :num_tokens_K] to achieve the same thing without relying on broadcasting. Note that a position pointer like self.ptr_current_pos needs to be added, as in gpt_with_kv_cache_optimized.py (see the sketch after this list).
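
To make the two options concrete, here is a minimal standalone sketch (hypothetical sizes; ptr_current_pos stands in for the position pointer mentioned above) showing that the broadcast-based mask and the explicit slice produce identical masked scores:

```python
# Compare the broadcast-based mask with the explicit slice during
# single-token KV-cache decoding (standalone sketch, hypothetical sizes).
import torch

torch.manual_seed(123)

context_length = 8
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

num_tokens = 1           # one new token per decode step
ptr_current_pos = 4      # four tokens already cached
num_tokens_K = ptr_current_pos + num_tokens  # keys = cached + new

# Attention scores for the single new query against all keys so far.
attn_scores = torch.randn(num_tokens, num_tokens_K)

# Original code: slice yields shape (1, 1); masked_fill_ broadcasts it.
mask_a = mask.bool()[:num_tokens, :num_tokens]  # (1, 1), value False
scores_a = attn_scores.clone().masked_fill_(mask_a, -torch.inf)

# Explicit alternative: slice the row at the query's absolute position.
mask_b = mask.bool()[ptr_current_pos : ptr_current_pos + num_tokens, :num_tokens_K]
scores_b = attn_scores.clone().masked_fill_(mask_b, -torch.inf)

print(torch.equal(scores_a, scores_b))  # True: nothing is masked either way
```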

martinzwm commented Jun 23 '25 13:06

Thanks for the suggestion. I think it's a good idea to make the code more explicit here.

rasbt commented Jun 23 '25 17:06