flash-attention
Support attention mask and bias for AlphaFold2
We added support for an (additive) attention_mask and an (additive) attention_bias, so that flash-attention can be used in the Evoformer of AlphaFold2. We benchmarked it in Uni-Fold, where it achieved a further ~20% speed-up.
Comments and suggestions are very welcome!
Some benchmark results:
Training GPU hours: (table not reproduced here)
Inference time and memory cost (one Evoformer layer, without chunking): (table not reproduced here)
Currently, we implement the following cases for the attention bias/mask.
The supported shapes of q/k/v are as follows:
q's shape: [total_size * head, seq_q, head_dim]
k's shape: [total_size * head, seq_k, head_dim]
v's shape: [total_size * head, seq_k, head_dim]
Attention Mask
[total_size, head, seq_q, seq_k]
1. total_size must be the same as q's total_size
2. head must be 1 or the same as q's head
3. seq_q must be 1
4. seq_k must be the same as k's seq_k
Attention Bias
[total_size, head, seq_q, seq_k]
1. total_size must be 1
2. head must be the same as q's head
3. seq_q must be the same as q's seq_q
4. seq_k must be the same as k's seq_k
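To make these constraints concrete, here is a minimal pure-PyTorch reference showing the expected shapes and how the additive mask and bias are applied. This is only an illustration of the shapes and broadcasting, with made-up sizes, not the fused kernel or its exact API:

import torch

def reference_attention(q, k, v, attn_mask=None, attn_bias=None, total_size=1, head=1):
    # q: [total_size * head, seq_q, head_dim]; k, v: [total_size * head, seq_k, head_dim]
    # attn_mask (additive): [total_size, 1 or head, 1, seq_k]
    # attn_bias (additive): [1, head, seq_q, seq_k]
    seq_q, seq_k, head_dim = q.shape[1], k.shape[1], q.shape[2]
    scores = torch.einsum("bqd,bkd->bqk", q, k) / head_dim ** 0.5
    scores = scores.view(total_size, head, seq_q, seq_k)
    if attn_mask is not None:
        scores = scores + attn_mask      # broadcasts over head and seq_q
    if attn_bias is not None:
        scores = scores + attn_bias      # broadcasts over total_size
    probs = scores.softmax(dim=-1)
    out = torch.einsum("bhqk,bhkd->bhqd", probs, v.view(total_size, head, seq_k, head_dim))
    return out.reshape(total_size * head, seq_q, head_dim)

# example with hypothetical sizes
total_size, head, seq_q, seq_k, head_dim = 2, 8, 64, 64, 32
q = torch.randn(total_size * head, seq_q, head_dim)
k = torch.randn(total_size * head, seq_k, head_dim)
v = torch.randn(total_size * head, seq_k, head_dim)
attn_mask = torch.zeros(total_size, 1, 1, seq_k)   # e.g. a key-padding mask (0 or -inf)
attn_bias = torch.zeros(1, head, seq_q, seq_k)     # e.g. the pair bias in the Evoformer
out = reference_attention(q, k, v, attn_mask, attn_bias, total_size, head)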
Thanks so much for the great work, and congrats on the speedup on Uni-Fold!
I'll have more time this weekend to review carefully.
Great, any suggestions are welcome. We still have a few things to refine to make it more broadly applicable:
- fixing the incompatible interface in flash_attn_interface.py
- adding our unit tests for the mask and bias interface
- supporting odd sequence lengths in the last dimension of the mask/bias
It does not work if the mask or bias has an odd sequence length. CUDA error (/tmp/pip-req-build-k5fpgkes/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu:140): misaligned address
@guolinke @robotcator Do we need both mask & bias, or would a single bias suffice? I think that could simplify the code & reduce compilation time.
Attention Mask [total_size, head, seq_q, seq_k]
- total_size must be the same as q's total_size
- head must be 1 or the same as q's head
- seq_q must be 1
- seq_k must be the same as k's seq_k
From the shape given, my understanding is that the mask is a key-padding mask. Does that change across different layers for the same batch?
If the key-padding mask doesn't change across layers, then the most performant way to do it is to remove padding before the first layer (we have a function unpad_input), run through all the layers, then optionally add back the padding tokens.
Is my understanding correct?
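For reference, a minimal sketch of this unpad/re-pad idea in plain PyTorch indexing (it deliberately does not use the library's unpad_input helper, whose exact signature may differ):

import torch
import torch.nn.functional as F

def unpad(x, key_padding_mask):
    # x: [batch, seq, dim]; key_padding_mask: [batch, seq] bool, True = keep token
    indices = key_padding_mask.flatten().nonzero(as_tuple=True)[0]
    x_unpad = x.reshape(-1, x.shape[-1])[indices]              # [total_tokens, dim]
    seqlens = key_padding_mask.sum(dim=1, dtype=torch.int32)
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    return x_unpad, indices, cu_seqlens

def repad(x_unpad, indices, batch, seq):
    # scatter the unpadded tokens back into a zero-initialized padded tensor
    out = x_unpad.new_zeros(batch * seq, x_unpad.shape[-1])
    out[indices] = x_unpad
    return out.view(batch, seq, -1)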
Thank you for reporting the misaligned-address error with odd sequence lengths. Supporting an odd length in the last dimension of the mask/bias is on our progress list.
Thanks for the suggestion @tridao. Flattening the non-padded input is not trivial in AlphaFold2:
- there are 2 representations (token-level and pair-level) and 4 kinds of attention, and their masks/biases in the Evoformer are different;
- the 2 representations communicate at each Evoformer layer, and the padded shape is more convenient for that computation.
I see, thanks for explaining, this is very helpful. How about we pass in a tensor (type int) with the sequence lengths of the key for each batch? That might be faster (we read 1 int instead of one vector of mask) and simpler (reduce code complexity and compilation time). Would this work for the alphafold2 use case?
If this sounds reasonable I'll take a stab at implementing the seqlen_k masking and then rebase and merge the bias part from this PR?
Another way to phrase this question: is the mask for each sequence always of the form [0, 0, ..., 0, -inf, -inf ...]? Or could they have the form [0, -inf, 0, ..., -inf, 0]? That is, are the masked keys always at the end of the sequence?
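To illustrate the proposal: a per-batch seqlen_k is equivalent to an additive mask whose -inf entries all sit at the end of each key row. A small hypothetical sketch of that conversion (the function name is made up):

import torch

def seqlens_to_additive_mask(seqlen_k, max_seq_k):
    # seqlen_k: [batch] int tensor of valid key lengths
    # returns an additive mask of shape [batch, 1, 1, max_seq_k]
    positions = torch.arange(max_seq_k, device=seqlen_k.device)
    invalid = positions[None, :] >= seqlen_k[:, None]          # True past the valid length
    mask = torch.zeros(seqlen_k.shape[0], 1, 1, max_seq_k, device=seqlen_k.device)
    return mask.masked_fill(invalid[:, None, None, :], float("-inf"))

# seqlens_to_additive_mask(torch.tensor([3, 4]), 6) gives rows
# [0, 0, 0, -inf, -inf, -inf] and [0, 0, 0, 0, -inf, -inf]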
@tridao Hi, sorry for the late reply. Using the key-padding-mask style would be a really good way to reduce code complexity and compilation time, but we checked and the masked keys are not always at the end of the sequence.
One case is the gen_msa_attn_mask function here, which generates two types of mask, i.e. row_mask and col_mask.
The row_mask is generated from the original msa_mask tensor and the col_mask from the transpose of the msa_mask tensor, so the col_mask's masked keys are not at the end of the sequence. A minimal example is as follows.
# the original `msa_mask` tensor
tensor([[0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf]])
# the transpose of the `msa_mask` tensor
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [-inf, 0.],
        [-inf, -inf],
        [-inf, -inf]])
Another case is that the msa_mask's masked keys are not always padded at the end of the sequence; they can appear at any position.
So we chose to use the attention-mask style rather than the key-padding-mask style. If anything is unclear, please feel free to contact us. We also suffer from the compilation-time problem and hope we can find a way to tackle it.
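gen_msa_attn_mask lives in Uni-Fold; the sketch below is only a hypothetical reconstruction of the idea (turning a 0/1 msa_mask into additive row and column masks), not its actual implementation:

import torch

def make_row_col_masks(msa_mask):
    # msa_mask: [n_seq, n_res], 1.0 for valid positions, 0.0 for padding
    neg_inf = float("-inf")
    # row-wise attention: keys are the residues of one MSA row
    row_mask = torch.zeros_like(msa_mask).masked_fill(msa_mask == 0, neg_inf)
    # column-wise attention: keys are the MSA rows of one residue column, so transpose first
    col_mask = torch.zeros_like(msa_mask.t()).masked_fill(msa_mask.t() == 0, neg_inf)
    return row_mask, col_mask

msa_mask = torch.tensor([[1., 1., 1., 0., 0., 0.],
                         [1., 1., 1., 1., 0., 0.]])
row_mask, col_mask = make_row_col_masks(msa_mask)
# col_mask[3] == tensor([-inf, 0.]): the masked key is not at the end of the key sequence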
@robotcator I encounter gradient overflow when attn_mask is not None or attn_bias is not None. Could you give me some advice?
@tridao Any update on merging this, or the part to support arbitrary masks and biases?
I just haven't had time to review and merge it (it's a pretty big change). Still trying to figure out a good way to support both mask and bias without increasing compilation time by 4x.
Regarding the gradient overflow when attn_mask or attn_bias is not None: do you mean overflow or NaN? And can you provide the shapes of your inputs?
The model is trained with FP16. With FP16 training the loss may explode, and we progressively lower the dynamic loss scale until it reaches its minimum value. If attn_bias is not None, the loss scale quickly drops to the minimum value at the beginning of training (loss-scale plot omitted).
The code snippet is shown below (following https://github.com/dptech-corp/flash-attention/blob/main/flash_attn/attention.py):
def attention(q, k, v, attn_bias, seq_len):
    # q (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # k (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # v (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
    # attn_bias (bsz, num_heads, seq_len, seq_len) = (128, 12, seq_len, seq_len)
    # bsz (= 128 here) and flash_attn_unpadded_func come from the enclosing scope
    cu_seqlens = torch.arange(
        0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=q.device
    )
    attn = flash_attn_unpadded_func(
        q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
        attn_mask=None, attn_bias=attn_bias,
        dropout_p=0.0,
        softmax_scale=1.0, causal=False
    )
    return attn
It seems not trivial to figure out. Here are some ideas from my side: 1) check whether the half precision overflows due to its limited representation range; 2) the attention bias & mask support is not as flexible as the PyTorch version. The broadcast mechanism is very flexible in PyTorch, but it takes more effort to implement when all the operations are fused into one kernel. We implemented a limited set of shapes to fit our model; it is not generalized to all models, so please check the supported-shape list carefully.
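On point 1), one quick sanity check is to look at whether the inputs or the bias already exceed the FP16 range before they reach the kernel. A rough sketch (check_fp16_range is a made-up helper; the tensor names follow the snippet above):

import torch

def check_fp16_range(name, t):
    # does this tensor fit the fp16 representable range?
    fp16_max = torch.finfo(torch.float16).max
    amax = t.detach().float().abs().max().item()
    print(f"{name}: max |value| = {amax:.3e}, fp16 max = {fp16_max:.3e}, overflow = {amax > fp16_max}")

# e.g. check_fp16_range("attn_bias", attn_bias); check_fp16_range("q", q)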
Thank you for the reply.
Hi, thanks everyone for bringing up this enhancement! Is this PR a way to support custom attention masks? Is this the best workaround so far, given it is not officially supported yet?
For the padding mask, I think the official repo already supports it. For custom attention masks, we also support some shapes, but not all.
@robotcator I have a question about attn_bias: if my attn_bias is trainable, will flash attention compute the gradient of attn_bias automatically?
Hello, what's the status of this PR? Is the code in a usable state? I didn't quite get it from the discussion above. Thanks for your work, awesome job!
I don't know whether it's too late to reply: yes, the gradient of attn_bias is computed automatically.
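For a quick way to convince yourself, here is a plain-PyTorch autograd stand-in (not the fused kernel): any additive bias that enters the softmax receives a gradient once requires_grad is set.

import torch

total_size, head, seq_q, seq_k, head_dim = 2, 4, 8, 8, 16
q = torch.randn(total_size, head, seq_q, head_dim)
k = torch.randn(total_size, head, seq_k, head_dim)
v = torch.randn(total_size, head, seq_k, head_dim)
attn_bias = torch.zeros(1, head, seq_q, seq_k, requires_grad=True)

scores = torch.einsum("bhqd,bhkd->bhqk", q, k) / head_dim ** 0.5 + attn_bias
out = scores.softmax(dim=-1) @ v
out.sum().backward()
print(attn_bias.grad is not None)   # True: the bias receives a gradient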
Guys, let's face it: it's like there is a hidden force not allowing this one to go through. Someone is gatekeeping.
For anyone still looking for this, see: https://pytorch.org/blog/flexattention/