flash-attention

Any plans to support tree attention mask?

Open KexinFeng opened this issue 10 months ago • 7 comments

Tree attention masks are already supported in huggingface/transformers (https://github.com/huggingface/transformers/pull/27539). Support here would be very helpful for speculative decoding applications. More specifically, in flash_attn/flash_attn_interface.py, flash_attn_with_kvcache would need to accept the tree attention mask as an argument.
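To make the request concrete, here is a minimal pure-Python sketch of what such a mask could look like. It assumes the draft tree is given as a hypothetical `parents` list (the parent index of each draft token, or -1 for tokens that hang directly off the cached prefix, with parents listed before their children) and that every draft token may attend to all `cache_len` tokens already in the KV cache. The names are illustrative only; this is not the transformers or flash-attention API.

```python
def tree_attention_mask(parents, cache_len):
    """Return a [Q][cache_len + Q] boolean mask (True = may attend).

    parents[i] is the index of draft token i's parent within the draft
    tree, or -1 if it hangs directly off the cached prefix. Parents must
    come before their children (topological order).
    """
    q = len(parents)
    mask = [[False] * (cache_len + q) for _ in range(q)]
    for i, p in enumerate(parents):
        if p >= i:
            raise ValueError("parents must precede their children")
        row = mask[i]
        # Every draft token sees the full cached prefix.
        for j in range(cache_len):
            row[j] = True
        # Walk up the tree: a token sees itself and all of its ancestors.
        node = i
        while node != -1:
            row[cache_len + node] = True
            node = parents[node]
    return mask
```

For the tree `parents = [-1, 0, 0, 1]` with `cache_len = 2`, token 3 sees the prefix, token 0, token 1 and itself, but not its sibling branch (token 2). With a simple chain (`parents = [-1, 0, 1, ...]`) the result reduces to the usual causal mask with a KV-cache offset. A mask like this can be turned into a tensor and passed as a boolean `attn_mask` to PyTorch's `scaled_dot_product_attention` (where True means "take part in attention"); an equivalent argument for flash_attn_with_kvcache is what this issue asks for.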

Do you have any plans to support it in the near future?

Thanks

Related questions: https://github.com/Dao-AILab/flash-attention/issues/840, https://github.com/Dao-AILab/flash-attention/issues/918

KexinFeng, Apr 21 '24 04:04

Sure, we'll just need someone to contribute :D

tridao, Apr 21 '24 04:04

I'm keen to try supporting the generic mask case, e.g. a [B, Q, K] bool tensor, with conditional execution. Ideally this covers a lot of masking cases, though I suspect optimised kernels would work better for more structured masks (like tree masks).
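A reference for the generic case might look like the sketch below: plain Python, a single head, and a [B, Q, K] boolean mask where True means "attend". Conditional execution is modelled by never computing scores for masked-out keys. The function name is hypothetical, and returning zeros for a fully masked row is an assumption of this sketch, not documented flash-attention behaviour.

```python
import math


def masked_attention(q, k, v, mask):
    """Reference attention with a generic boolean mask.

    q: [B][Q][D], k: [B][K][D], v: [B][K][Dv], mask: [B][Q][K] (True = attend).
    Scores for masked-out keys are never computed. A fully masked row
    returns zeros (a convention chosen for this sketch).
    """
    out = []
    for qb, kb, vb, mb in zip(q, k, v, mask):
        d = len(qb[0])
        dv = len(vb[0])
        scale = 1.0 / math.sqrt(d)
        rows = []
        for qi, mrow in zip(qb, mb):
            # Conditional execution: only visit keys the mask allows.
            keep = [j for j, allowed in enumerate(mrow) if allowed]
            if not keep:
                rows.append([0.0] * dv)
                continue
            scores = [scale * sum(a * b for a, b in zip(qi, kb[j])) for j in keep]
            m = max(scores)  # subtract the max for numerical stability
            weights = [math.exp(s - m) for s in scores]
            total = sum(weights)
            rows.append([
                sum(w * vb[j][c] for w, j in zip(weights, keep)) / total
                for c in range(dv)
            ])
        out.append(rows)
    return out
```

A real kernel would apply the same skipping at tile granularity rather than per element, which is where structured masks can help.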

thorinf, Apr 29 '24 12:04

I don't see much difference between a generic mask and a structured mask. For a tree mask, the mask argument would also have shape [B, Q, K]. The 4D attention mask mentioned above is nothing but [B, H, Q, K], with H being the number of heads.

If you are able to implement a generic mask, then support for a structured mask will follow.

KexinFeng, May 06 '24 16:05

What I mean is that for a structured mask you don't necessarily have to create a bool tensor. In the causal case, the kernel can hardcode the rule that key j is ignored whenever j > i + k_cache, which saves a little memory. If the mask is structured, the locations you'll visit are predictable.
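To illustrate the point about structured masks, here is a minimal sketch, assuming `i` indexes the new query tokens, `j` indexes keys over cache plus new tokens, and `cache_len` plays the role of `k_cache` above. The function names are hypothetical, not part of flash-attention. Because the predicate depends only on indices, a kernel could classify a whole tile as empty, full, or partial without reading any mask data.

```python
def causal_allowed(i, j, cache_len):
    """Query i (0-based among new tokens) may attend to key j
    (0-based over cache + new tokens) unless j > i + cache_len."""
    return j <= i + cache_len


def classify_block(q_start, q_end, k_start, k_end, cache_len):
    """Classify the tile rows [q_start, q_end) x cols [k_start, k_end)
    using index arithmetic only; no mask tensor is read."""
    # Smallest key index is beyond the largest query's horizon.
    if k_start > (q_end - 1) + cache_len:
        return "empty"    # every key is in the future: skip the tile
    # Largest key index is within the smallest query's horizon.
    if (k_end - 1) <= q_start + cache_len:
        return "full"     # nothing masked: no per-element check needed
    return "partial"      # apply causal_allowed element-wise
```

For example, `classify_block(0, 4, 8, 12, cache_len=2)` gives `"empty"` and `classify_block(4, 8, 0, 4, cache_len=0)` gives `"full"`. A tree mask has no such closed form, which is why it has to be passed in as data.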

thorinf, May 06 '24 18:05

I see. Yes, in the causal case the bool mask tensor argument is indeed not required. For the tree attention mask, however, this argument is unavoidable. But this probably doesn't add much implementation complexity, since the causal mask will internally be converted to such a tensor anyway. @thorinf Looking forward to your PR!

KexinFeng, May 07 '24 16:05

Hello, sorry for the naive questions, but:

  1. Why do you need structured masking? Can't you do something similar with attention biases?
  2. Are you hoping to skip blocks that are entirely masked, or will you still compute attention over the full matrix?
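On question 1, a boolean mask and an additive bias are interchangeable inside the softmax: a masked position simply gets a bias of -inf. The pure-Python sketch below, with hypothetical helper names, shows that equivalence and, for question 2, how per-tile "any unmasked entry" flags could be precomputed from an arbitrary mask so that fully masked tiles can be skipped. Whether flash-attention would do it this way is an assumption; unlike the causal case, it needs a pass over the mask data.

```python
import math


def mask_to_bias(mask):
    """Turn a boolean mask (True = attend) into an additive bias:
    0.0 where attention is allowed, -inf where it is masked."""
    return [[0.0 if allowed else -math.inf for allowed in row] for row in mask]


def softmax_with_bias(scores, bias):
    """Row-wise softmax of scores + bias; exp(-inf) == 0.0 drops masked keys.
    Assumes every row has at least one finite entry."""
    out = []
    for s_row, b_row in zip(scores, bias):
        z = [s + b for s, b in zip(s_row, b_row)]
        m = max(z)
        e = [math.exp(x - m) for x in z]
        t = sum(e)
        out.append([x / t for x in e])
    return out


def nonempty_tiles(mask, block_q, block_k):
    """For each (block_q x block_k) tile, report whether any entry is True.
    With an arbitrary mask (or bias) this requires scanning the data;
    tiles reported False could be skipped entirely by a kernel."""
    n_q, n_k = len(mask), len(mask[0])
    return [
        [
            any(mask[i][j]
                for i in range(qb, min(qb + block_q, n_q))
                for j in range(kb, min(kb + block_k, n_k)))
            for kb in range(0, n_k, block_k)
        ]
        for qb in range(0, n_q, block_q)
    ]
```

A bias can express more than a mask (soft penalties, not just on/off), but with a dense bias the kernel cannot tell which tiles are fully masked without looking at the values.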

It might help me understand this a bit more :)

jkobject, May 23 '24 09:05

  1. Why do you need structured masking? Can't you do something similar with attention biases?

Please check out this blog post describing 4D masks: https://huggingface.co/blog/poedator/4d-masks

poedator, Aug 06 '24 21:08