
[v2] Attention Masking

Open MikeynJerry opened this issue 1 year ago • 18 comments

Is there any plan to add attention masking support? PyTorch's version of flash attention v1 included the ability to provide an attention mask in its implementation, and it would be very useful to have this feature in v2.

MikeynJerry avatar Jul 20 '23 01:07 MikeynJerry

In fact, when you pass an attention mask to PyTorch's implementation, flash attention doesn't get used.

leizhao1234 avatar Jul 20 '23 02:07 leizhao1234

Yes, facing the same issue. @tridao Can you please take a look at this and respond when you are available?

balachandarsv avatar Jul 21 '23 07:07 balachandarsv

Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.

tridao avatar Jul 21 '23 07:07 tridao

I thought masking is supported through flash_attn_varlen_func

https://github.com/Dao-AILab/flash-attention/blob/d30f2e1cd50185c98ed88c0684b4a603f15bee37/flash_attn/flash_attn_interface.py#L454C21-L454C21

PeterL1n avatar Aug 03 '23 21:08 PeterL1n
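For what it's worth, a minimal sketch of what the varlen path covers (assuming the current `flash_attn_varlen_func` keyword names): sequences are packed with padding removed and boundaries are passed via `cu_seqlens`, so it handles padding-style masks but not arbitrary per-position masks.

```python
# Hedged sketch (not from the repo's docs) of padding-style masking via
# flash_attn_varlen_func: sequences are packed with padding removed, and
# cu_seqlens tells the kernel where each sequence starts. This handles
# variable-length / padding masks, but not arbitrary per-position masks.
import itertools
import torch
from flash_attn import flash_attn_varlen_func

nheads, headdim = 8, 64
seqlens = [5, 3, 7]                          # three sequences of different lengths
total_tokens = sum(seqlens)

# Packed (total_tokens, nheads, headdim) tensors, fp16 on GPU.
q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence lengths: [0, 5, 8, 15]
cu_seqlens = torch.tensor([0] + list(itertools.accumulate(seqlens)),
                          device="cuda", dtype=torch.int32)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
    causal=False,
)
```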

I have tested v1.0.7 and v2.0.4. It turns out that neither of them supports attention masks:

  • A: using flash attention, with an attention mask
  • B: not using flash attention, with an attention mask

The results of A and B are different.

zhipeng93 avatar Sep 15 '23 02:09 zhipeng93
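To illustrate the A/B comparison described above (a hedged reconstruction, not the commenter's actual script): `flash_attn_func` has no mask argument at all, so any mask is simply dropped on the flash path.

```python
# Hedged reconstruction of the A/B comparison above. flash_attn_func takes no
# mask argument, so path A ignores the mask entirely; path B is a plain PyTorch
# reference that applies an additive mask. The outputs differ wherever the mask
# blocks something.
import math
import torch
from flash_attn import flash_attn_func

bsz, seqlen, nheads, headdim = 2, 16, 4, 64
q = torch.randn(bsz, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Additive mask: block attention to the last 4 key positions.
mask = torch.zeros(bsz, 1, seqlen, seqlen, device="cuda")
mask[..., -4:] = float("-inf")

# A: flash attention (no way to pass the mask).
out_a = flash_attn_func(q, k, v)                             # (bsz, seqlen, nheads, headdim)

# B: reference attention with the mask applied.
qr, kr, vr = (x.transpose(1, 2).float() for x in (q, k, v))  # (bsz, nheads, seqlen, headdim)
scores = qr @ kr.transpose(-2, -1) / math.sqrt(headdim) + mask
out_b = (scores.softmax(dim=-1) @ vr).transpose(1, 2)

print(torch.allclose(out_a.float(), out_b, atol=1e-2))       # False, as reported
```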

This paper might be relevant: https://arxiv.org/abs/2306.01160.

There are several related issues:

  • https://github.com/Dao-AILab/flash-attention/issues/530
  • https://github.com/Dao-AILab/flash-attention/issues/506
  • https://github.com/Dao-AILab/flash-attention/issues/424
  • https://github.com/Dao-AILab/flash-attention/issues/342
  • https://github.com/Dao-AILab/flash-attention/issues/307
  • https://github.com/Dao-AILab/flash-attention/issues/242
  • https://github.com/Dao-AILab/flash-attention/issues/127
  • https://github.com/Dao-AILab/flash-attention/issues/119
  • https://github.com/Dao-AILab/flash-attention/issues/17

I believe pytorch 2.1 will have a memory efficient attention implementation that supports arbitrary masks: https://github.com/pytorch/pytorch/issues/96099

samvanstroud avatar Sep 25 '23 21:09 samvanstroud
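For reference, a minimal sketch of the PyTorch route mentioned above (assuming PyTorch 2.x): `scaled_dot_product_attention` accepts an arbitrary `attn_mask` and dispatches to a fused or memory-efficient kernel when one supports the inputs.

```python
# Hedged sketch: PyTorch's SDPA with an arbitrary boolean mask. Which backend
# (flash / memory-efficient / math) actually runs depends on the inputs and the
# PyTorch version; shapes here are just an example.
import torch
import torch.nn.functional as F

bsz, nheads, seqlen, headdim = 2, 8, 128, 64
q = torch.randn(bsz, nheads, seqlen, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Boolean mask: True = attend, False = masked out.
attn_mask = torch.rand(bsz, 1, seqlen, seqlen, device="cuda") > 0.1

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```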

@tridao Hello, I plan to add a bias mask in FlashAttention-2. I noticed that, in order to fuse the scale and add operations in scale_apply_exp2, the scaling is delayed until after the maximum value is computed. I plan to support the bias mask in the apply_mask_causal function. If a bias mask is supported, it seems the current FFMA optimization in scale_apply_exp2 could be dropped, since applying scale and bias together can still benefit from FFMA. Do you have any suggestions?

defei-coder avatar Oct 20 '23 12:10 defei-coder
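For readers following along, here is my reading of the delayed-scaling trick being discussed, as a hedged numeric sketch rather than the actual kernel code: the softmax scale is folded into the exp2 exponent so each score costs roughly one fused multiply-add, and an additive bias just contributes one more term to that exponent.

```python
# Hedged numeric sketch of the delayed-scaling idea (my reading, not the CUDA
# kernel): softmax(scale * S + bias) is computed as exp2(S * a + bias * log2e - m)
# with a = scale * log2(e), so the per-element work is a fused multiply-add
# followed by exp2. The check below confirms the identity numerically.
import math
import torch

s = torch.randn(4, 8)              # unscaled scores S = Q K^T
bias = torch.randn(4, 8)           # additive bias / mask
scale = 1 / math.sqrt(64)
log2e = math.log2(math.e)

# Reference: softmax over (scale * S + bias)
ref = torch.softmax(scale * s + bias, dim=-1)

# Delayed scaling: exp(x) = exp2(x * log2(e)); fold scale * log2(e) into one factor.
a = scale * log2e
x = s * a + bias * log2e           # one FMA per element, plus the bias term
m = x.max(dim=-1, keepdim=True).values
p = torch.exp2(x - m)
out = p / p.sum(dim=-1, keepdim=True)

print(torch.allclose(ref, out, atol=1e-5))  # True
```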

flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf to mask positions.

zhangyipin avatar Nov 06 '23 07:11 zhangyipin
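If I read the Triton implementation's interface correctly (the exact signature and dtype requirements here are assumptions, not documented usage), the suggestion above looks roughly like this, with blocked positions receiving a -inf additive bias.

```python
# Hedged sketch of the suggestion above. Assumptions: the Triton kernel exposes
# flash_attn_func(q, k, v, bias=..., causal=...) with q/k/v of shape
# (batch, seqlen, nheads, headdim) and an additive bias of shape
# (batch, nheads, seqlen_q, seqlen_k). Masked positions get a -inf bias.
import torch
from flash_attn.flash_attn_triton import flash_attn_func

bsz, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(bsz, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Build the "mask" as an additive bias: -inf where attention is blocked, 0 elsewhere.
keep = torch.rand(bsz, nheads, seqlen, seqlen, device="cuda") > 0.1
bias = torch.zeros(bsz, nheads, seqlen, seqlen, device="cuda", dtype=torch.float16)
bias = bias.masked_fill(~keep, float("-inf"))

out = flash_attn_func(q, k, v, bias=bias, causal=False)
```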

flash_attn/flash_attn_triton.py supports a bias input; you can use bias=-inf to mask positions.

This is a good point, but the example itself does not work with PyTorch 2.0+ (i.e., Triton 2.0+) 😭

wehos avatar Feb 29 '24 18:02 wehos

Anyone have tips on custom masks with flash attention for training?

(I need this to train encoder-decoder models with variable-length sequences using non-causal masks.)

This came up in a recent article: https://www.yitay.net/blog/training-great-llms-entirely-from-ground-zero-in-the-wilderness

The other striking thing is how little support these codebases have for large scale encoder-decoder training or even prefixLM training. To that end, even flash attention has consistently declined to provide support for prefixLM training (i.e., custom masks) despite reasonable demand on their github issues for whatever reason.

Curious what this would take or if it is still out of scope for the flash attention library?

Really grateful that this exists!! Just posting for visibility in case others have solved this problem :)

jaanli avatar Mar 06 '24 23:03 jaanli

Curious what this would take or if it is still out of scope for the flash attention library?

Not out of scope, it's just someone needs to go implement it :D

tridao avatar Mar 07 '24 00:03 tridao

Understood — thank you!! Will try using the varlen functions for now :)

jaanli avatar Mar 07 '24 03:03 jaanli

I was wondering if there have been any updates on this? AlphaFold3 uses a lot of attention pair biasing, and it would be tremendously useful to computational biology if flash attention supported attention biasing!

ardagoreci avatar May 26 '24 17:05 ardagoreci

I was wondering if there have been any updates on this? AlphaFold3 uses a lot of attention pair biasing, and it would be tremendously useful to computational biology if flash attention supported attention biasing!

Right, we still need someone to implement it.

tridao avatar May 26 '24 18:05 tridao

@tridao Was wondering, what needs to be done for this to be implemented? (I'm assuming efficiently? Otherwise it seems quite simple.)

I need a similar feature (arbitrary attention masks) but I figured I might take a stab at just implementing it if it still needs to be done.

alexzhang13 avatar Jul 13 '24 19:07 alexzhang13

I've implemented a version of custom masking for FA2 in Triton: https://github.com/alexzhang13/flashattention2-custom-mask

It suffices for my use case, but if something comes up where it's necessary to touch the FA3 code I may re-visit this.

alexzhang13 avatar Jul 21 '24 07:07 alexzhang13

Attention mask isn't supported (either in v1 or v2). I might implement it at some point but there are other priorities now.

Seems like the FlashAttention class does take in a key_padding_mask argument in its forward method. What would be the difference between this and the attention mask to be implemented? Cc @tridao. Thanks!

amyxlu avatar Aug 21 '24 22:08 amyxlu

As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.

tridao avatar Aug 21 '24 22:08 tridao
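A hedged conceptual sketch of what that looks like in user code (illustrative names and shapes, not the MHA module's internals): the padding mask is turned into packed tensors plus cu_seqlens, and no mask ever reaches the kernel.

```python
# Hedged conceptual sketch: a key_padding_mask is handled by dropping padded
# key/value positions and calling the varlen kernel on packed tensors, rather
# than passing a mask into the kernel. Illustrative only, not the MHA internals.
import torch
from flash_attn import flash_attn_varlen_func

bsz, seqlen, nheads, headdim = 2, 10, 4, 64
q = torch.randn(bsz, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# True = real token, False = padding (second sequence has 4 padded keys).
key_padding_mask = torch.ones(bsz, seqlen, dtype=torch.bool, device="cuda")
key_padding_mask[1, -4:] = False

# "Remove" padded keys/values: pack valid tokens into (total_valid, nheads, headdim).
k_packed = k[key_padding_mask]
v_packed = v[key_padding_mask]
seqlens_k = key_padding_mask.sum(dim=1)
cu_seqlens_k = torch.cat([torch.zeros(1, dtype=torch.int32, device="cuda"),
                          seqlens_k.cumsum(0).to(torch.int32)])

# Queries are left unpadded here for simplicity (every query position kept).
q_packed = q.reshape(bsz * seqlen, nheads, headdim)
cu_seqlens_q = torch.arange(0, (bsz + 1) * seqlen, seqlen,
                            dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q_packed, k_packed, v_packed,
    cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=seqlen, max_seqlen_k=int(seqlens_k.max()),
)
```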

As you can see in the code, key_padding_mask just removes elements from keys and values before passing to the flash attention kernel. There's no attention mask passed to the kernel.

Is there any plan to support key_padding_mask in MHA in v2? My understanding is that this was supported in v1 (in flash_attn.flash_attention.FlashMHA), but in v2, one can only use key_padding_mask when use_flash_attn is False (in flash_attn.modules.mha.MHA). Thank you.

krejciadam avatar Oct 16 '24 08:10 krejciadam

Hi Everyone, I recently published a paper at the ENLSP Workshop @ NeurIPS 2024 that addresses this problem. The paper can be found here: https://arxiv.org/pdf/2409.15097

I have the code, but it's currently in a private repository, as I am still cleaning it up. If someone wants access to the repo, just send a mail to: [email protected]

Meanwhile, I realised that the PyTorch team has already implemented a change which uses pretty much the same method I used. (I came up with my method independently for a university project; the PyTorch blog came out about half a month after my project.)

Anyway, TL;DR: PyTorch has now enabled custom masking of flash attention, which you can find here: https://pytorch.org/blog/flexattention/ (And I am a sad man, as my method will never be used.)

agshar96 avatar Oct 18 '24 07:10 agshar96
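For anyone landing here, a minimal sketch of that FlexAttention route (assuming PyTorch 2.5+ and the API names from the linked blog post): a mask_mod function describes which (query, key) pairs may attend, and the kernel applies it without materializing a full mask.

```python
# Hedged sketch of FlexAttention with a custom mask (assumes PyTorch 2.5+).
# The mask_mod returns True for allowed (query, key) pairs; create_block_mask
# turns it into a sparse block mask the kernel can exploit.
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

bsz, nheads, seqlen, headdim = 2, 8, 256, 64
q = torch.randn(bsz, nheads, seqlen, headdim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Example custom mask: causal attention restricted to a 64-token sliding window.
def sliding_causal(b, h, q_idx, kv_idx):
    return (q_idx >= kv_idx) & (q_idx - kv_idx <= 64)

block_mask = create_block_mask(sliding_causal, B=bsz, H=nheads,
                               Q_LEN=seqlen, KV_LEN=seqlen, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
# In practice one would wrap flex_attention in torch.compile for speed.
```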