
Are there any plans for supporting an explicit attention mask?

Open · Avelina9X opened this issue 1 year ago · 5 comments

I've noticed that the Triton implementation supports an explicit attention bias, which can emulate arbitrary mask shapes by using large negative values. However, is there any planned support for explicit (boolean) masks in the CUDA implementation?
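To make the bias trick concrete: a boolean mask folds into an additive bias by adding 0 where attention is allowed and a very large negative number (or -inf) where it is not, so those positions get zero weight after the softmax. Below is a minimal pure-Python sketch of that equivalence, not the Triton or CUDA code; the helper names `mask_to_bias` and `attention_weights` are hypothetical and chosen for illustration.

```python
import math

def mask_to_bias(mask, neg=float("-inf")):
    """Turn a boolean mask (True = may attend) into an additive bias.

    Kept positions get 0.0, masked positions get `neg`. A large finite
    negative such as -1e9 also works, since exp() of it underflows to 0.0.
    """
    bias = []
    for row in mask:
        if not any(row):
            # A fully masked row has no valid softmax; reject it up front.
            raise ValueError("every query row must keep at least one key")
        bias.append([0.0 if keep else neg for keep in row])
    return bias

def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(scores, bias):
    """Row-wise softmax(scores + bias), i.e. the attention probabilities."""
    return [
        softmax([s + b for s, b in zip(score_row, bias_row)])
        for score_row, bias_row in zip(scores, bias)
    ]

scores = [[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]]
mask = [[True, False, False], [True, True, False]]
print(attention_weights(scores, mask_to_bias(mask)))
# [[1.0, 0.0, 0.0], [0.5, 0.5, 0.0]]
```

This is why a bias input is strictly more general than a boolean mask; the cost is that a dense bias has one entry per (query, key) pair, all of which the kernel has to read from memory.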

I've noticed some requests for features like off-diagonal attention. An explicit attention mask would support this and any other arbitrary masking scheme (such as XLNet, attention sinks, or landmark attention) without having to hardcode each scheme in the kernel and expose it through an argument or a separate Python interface.
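As an illustration of how such schemes reduce to "just a mask", here is a minimal sketch that builds two of them as nested lists of booleans (True = the query may attend to the key). The function names and the sink-plus-window layout (every query sees the first `num_sinks` tokens plus a causal window of the `window` most recent tokens, in the style of attention-sink methods) are assumptions for illustration, not an existing flash-attention API.

```python
def causal_mask(n):
    """Standard causal mask: query q sees keys 0..q."""
    return [[k <= q for k in range(n)] for q in range(n)]

def sink_window_mask(n, window, num_sinks):
    """Causal sliding window plus always-visible leading 'sink' tokens.

    Query q sees key k if k <= q and either k is one of the first
    `num_sinks` tokens or k lies within the last `window` positions.
    """
    return [
        [k <= q and (k < num_sinks or q - k < window) for k in range(n)]
        for q in range(n)
    ]

def show(mask):
    """Render a mask as rows of 'x' (attend) and '.' (masked)."""
    return "\n".join("".join("x" if m else "." for m in row) for row in mask)

print(show(sink_window_mask(6, window=2, num_sinks=1)))
# x.....
# xx....
# xxx...
# x.xx..
# x..xx.
# x...xx
```

With an explicit-mask kernel, each new scheme would be a different input tensor rather than a new kernel flag. The trade-off is memory traffic: a dense mask must be read for every (query, key) pair, whereas a hardcoded pattern like causal masking can be computed from the row and column indices inside the kernel.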

Avelina9X, Feb 20 '24 16:02

It seems that the PyTorch attention implementation supports custom attention masks and also uses Flash-Attention 2: https://twitter.com/StasBekman/status/1736083447658225665. However, I'm not sure whether passing in an attention mask causes the op to dispatch to a non-FA2 kernel.

normster, Mar 05 '24 21:03

If there's an attn mask, PyTorch does not dispatch to the FA2 kernel, but rather to the kernel from xformers.

tridao, Mar 05 '24 21:03

Thanks for the info @tridao! Is support for arbitrary attention masks on your roadmap? This would be incredibly useful for some encoder-decoder and prefixLM models. Mandatory thank you for your amazing work!
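The prefix-LM case is a good example of a pattern that a plain causal flag cannot express: the prefix is attended bidirectionally while the continuation stays causal. A minimal sketch, assuming the boolean convention True = "may attend"; `prefix_lm_mask` is a hypothetical helper, not part of flash-attention.

```python
def prefix_lm_mask(n, prefix_len):
    """Prefix-LM mask (True = query q may attend to key k).

    Tokens in the prefix attend bidirectionally to the whole prefix;
    tokens after the prefix attend to the prefix and causally to
    themselves and earlier positions.
    """
    return [[k < prefix_len or k <= q for k in range(n)] for q in range(n)]

for row in prefix_lm_mask(5, prefix_len=2):
    print("".join("x" if m else "." for m in row))
# xx...
# xx...
# xxx..
# xxxx.
# xxxxx
```

Setting `prefix_len=0` recovers the ordinary causal mask, and `prefix_len=n` gives full bidirectional attention, so the prefix length interpolates between decoder-style and encoder-style masking.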

abdulfatir, Mar 29 '24 01:03

> If there's attn mask pytorch does not dispatch to FA2 kernel, rather the kernel from xformers.

Thanks for this valuable tip. No wonder torch.nn.functional.scaled_dot_product_attention doesn't give any speedup in my case.

xiabingquan, Sep 10 '24 06:09

I'm looking for attention bias/mask support too, in FA2 and ideally in FA3. Is there a roadmap for this? Thank you~

lin-ht, Sep 10 '24 17:09