[Question] Flash attention only applies to prefilling stage
I have a question arising from reading the code. I notice that in ~/lightllm/models/llama2/layer_infer/transformer_layer_infer.py, flash attention is applied only in the prefilling stage, i.e. context_attention_fwd, but not in the decoding stage, i.e. token_att_fwd. Am I correct in this understanding?
In principle, token attention doesn't conflict with flash attention. Do you plan to combine the two in the decoding stage?
Also, what is the obstacle to directly using the flash-attention repo together with token-level memory management?
@KexinFeng We have tried to implement this as a Triton kernel that uses flash attention, but currently it is not fast enough.
@hiworldwzj I'm actually exploring the same thing. It seems to me that, in principle, flash attention is completely orthogonal to token-wise memory management: flash attention is in essence a streaming way of computing attention, which on paper is naturally compatible with token-wise memory management, so its speedup should add directly on top of it. It is a little surprising that it is "not fast enough".
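To make that point concrete, here is a minimal sketch of what I mean by "streaming" attention over a token-managed KV cache. This is purely illustrative PyTorch, not lightllm's kernel; `k_pool`, `v_pool`, and `token_idx` are placeholder names I made up for the global KV pools and a request's token index table.

```python
# Illustrative only: online-softmax (flash-attention-style) decoding attention
# for a single query, reading K/V from a flat token pool through an index table.
import torch

def streaming_decode_attention(q, k_pool, v_pool, token_idx, block=64):
    """q: [H, D] decode query; k_pool/v_pool: [N_tokens, H, D] global pools;
    token_idx: [L] positions of this request's tokens inside the pools."""
    H, D = q.shape
    scale = D ** -0.5
    # Running statistics of the online softmax (the core of flash attention).
    m = torch.full((H,), float("-inf"), dtype=q.dtype, device=q.device)  # running max
    l = torch.zeros(H, dtype=q.dtype, device=q.device)                   # running denom
    acc = torch.zeros(H, D, dtype=q.dtype, device=q.device)              # running numerator
    for start in range(0, token_idx.numel(), block):
        idx = token_idx[start:start + block]
        k = k_pool[idx]                                   # [b, H, D], gathered per token
        v = v_pool[idx]                                   # [b, H, D]
        logits = torch.einsum("hd,bhd->hb", q, k) * scale # [H, b]
        m_new = torch.maximum(m, logits.max(dim=1).values)
        p = torch.exp(logits - m_new[:, None])            # [H, b]
        correction = torch.exp(m - m_new)                 # rescale previous statistics
        l = l * correction + p.sum(dim=1)
        acc = acc * correction[:, None] + torch.einsum("hb,bhd->hd", p, v)
        m = m_new
    return acc / l[:, None]                               # [H, D]
```

The only place the memory layout enters is the gather `k_pool[idx]` / `v_pool[idx]`; the online-softmax recurrence itself never needs a request's KV to be contiguous, which is why I would expect the two techniques to compose.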
vllm is actually working on exactly the same thing: https://github.com/vllm-project/vllm/pull/877#issuecomment-1697079089
I'm wondering if you have any work-in-progress implementation to share, so that the community can contribute?