
Question about FlashAttention and KV-cache

Open · KnowingNothing opened this issue 2 years ago · 0 comments

Hi, I notice that you use a KV-cache together with FlashAttention in CausalSelfAttention. As far as I understand, FlashAttention already implements causal self-attention inside its kernels, which means that for Q [batch, head, seq_len, d_k] x K [batch, head, seq_len, d_k], only the lower triangular half of the score matrix is computed. But in lit_llama/model.py, CausalSelfAttention uses a KV-cache, so during decoding only Q [batch, head, 1, d_k] x K [batch, head, seq_len, d_k] is passed to FlashAttention. I think the causal masking may then keep only the first element of the 1 x seq_len score row, i.e. the new token would attend only to the first cached position. I just want to confirm whether this is correct; please correct me if I am wrong. Could anybody kindly explain? Thanks.
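To illustrate what I mean, here is a minimal sketch using PyTorch's F.scaled_dot_product_attention as a stand-in for the FlashAttention kernel (this is not lit-llama's actual code, and the shapes B, H, T, d_k are made up). If the causal mask is built the usual way, as tril(ones(L, S)) over the 1 x T score matrix produced during cached decoding, the single new query only attends to the first cached position:

```python
import torch
import torch.nn.functional as F

B, H, T, d_k = 1, 2, 8, 16        # hypothetical batch, heads, cached length, head dim
q = torch.randn(B, H, 1, d_k)     # query for the single new token
k = torch.randn(B, H, T, d_k)     # keys from the KV-cache
v = torch.randn(B, H, T, d_k)     # values from the KV-cache

# A causal mask built as tril over the (L, S) = (1, T) score matrix
# keeps only the first column, so the new token only sees position 0.
causal_mask = torch.ones(1, T).tril().bool()
y_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# What cached decoding actually needs: the new token attends to all T
# cached positions, which is what no mask (or an all-True mask) gives.
y_full = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(y_masked, v[:, :, :1, :]))  # expected True: only the first cached value is used
print(torch.allclose(y_masked, y_full))          # expected False in general
```

So my question is whether FlashAttention's built-in causal handling does the right thing for the 1 x seq_len case, or whether this masking effectively reduces the result to that first element.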

KnowingNothing · Jul 23 '23 09:07