Llama patch for FlashAttention support fails with use_cache
I came across your llama_patch.py while looking to patch Llama for inference myself, and unless I'm doing something wrong, the implementation fails when use_cache=True and past_key_value is not None.
Specifically, during generation with use_cache=True, in this line query_states will have sequence length 1 while key_states and value_states will have length 1 + past_key_value[0].shape[-2], and thus these tensors won't stack.
https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/05d83eaa3c2ad6088227fa26dffb097e06439aef/training/utils/llama_patch.py#L76C3-L76C3
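To illustrate, here is a minimal sketch with made-up shapes (bsz, n_heads, head_dim, and past_len are arbitrary example values, not taken from the patch):

```python
import torch

# During incremental decoding with use_cache=True, the query covers only the
# newest token while key/value states include the cached past, so their
# sequence lengths differ.
bsz, n_heads, head_dim, past_len = 1, 32, 128, 10

query_states = torch.randn(bsz, n_heads, 1, head_dim)             # q_len = 1
key_states   = torch.randn(bsz, n_heads, 1 + past_len, head_dim)  # kv_len = 11
value_states = torch.randn(bsz, n_heads, 1 + past_len, head_dim)

# Packing q/k/v with torch.stack requires identical shapes, so this raises
# a RuntimeError as soon as the cache is non-empty.
try:
    qkv = torch.stack([query_states, key_states, value_states], dim=2)
    stacked = True
except RuntimeError as e:
    stacked = False
    print("stack failed:", e)
```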
I think this is also why the other llama patches referenced in the comments don't support flash attention + KV cache at the same time. Not sure if there's a clever workaround?
Hey @qmdnls,
What you say could very well be true. I created the patch only for training, where you use gradient checkpointing and no cache.
If you are interested in inference, I recommend checking out text-generation-inference.
I see, no worries! Just came across this and thought I would let you know since the patch seemed to specifically implement the case with past_key_value unlike the other referenced implementations.
Thanks for the pointer, I will have a look!