
Support the attention variants in AlphaFold2

Open guolinke opened this issue 3 years ago • 24 comments

We added support for an (additive) attention_mask and (additive) attention_bias so that flash-attention can be used in the Evoformer of AlphaFold2. We benchmarked it in Uni-Fold, and it achieved a further ~20% speed-up.

Comments and suggestions are very welcome!

Some benchmark results:

Training GPU hours: (see attached benchmark image)

Inference time and memory cost (one Evoformer layer, without chunking): (see attached benchmark image)

guolinke avatar Oct 13 '22 01:10 guolinke

Currently, we implement the following cases for the attention bias/mask.

Supported shapes of q/k/v:
q's shape: [total_size * head, seq_q, head_dim]
k's shape: [total_size * head, seq_k, head_dim]
v's shape: [total_size * head, seq_k, head_dim]

Attention mask: [total_size, head, seq_q, seq_k]
1. total_size must be the same as q's total_size
2. head must be 1 or the same as q's head
3. seq_q must be 1
4. seq_k must be the same as k's seq_k

Attention bias: [total_size, head, seq_q, seq_k]
1. total_size must be 1
2. head must be the same as q's head
3. seq_q must be the same as q's seq_q
4. seq_k must be the same as k's seq_k
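
For illustration, a minimal shape sketch of these constraints in plain PyTorch (tensor construction only; not the fork's exact call signature):

import torch

total_size, head, seq_q, seq_k, head_dim = 2, 8, 128, 128, 32

q = torch.randn(total_size * head, seq_q, head_dim, dtype=torch.float16, device='cuda')
k = torch.randn(total_size * head, seq_k, head_dim, dtype=torch.float16, device='cuda')
v = torch.randn(total_size * head, seq_k, head_dim, dtype=torch.float16, device='cuda')

# Additive mask: one per batch element, broadcast over heads (head dim may be 1) and queries (seq_q dim must be 1).
attn_mask = torch.zeros(total_size, 1, 1, seq_k, dtype=torch.float16, device='cuda')
attn_mask[0, :, :, 100:] = float('-inf')  # mask out the tail keys of the first batch element

# Additive bias: shared across the batch (total_size dim must be 1), full seq_q x seq_k per head.
attn_bias = torch.randn(1, head, seq_q, seq_k, dtype=torch.float16, device='cuda')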

robotcator avatar Oct 19 '22 03:10 robotcator

Thanks so much for the great work, and congrats on the speedup on Uni-Fold!

I'll have more time this weekend to review carefully.

tridao avatar Oct 19 '22 06:10 tridao

Thanks so much for the great work, and congrats on the speedup on Uni-Fold!

I'll have more time this weekend to review carefully.

Great, any suggestions are welcome. We still have a few things to refine to make it more broadly applicable:

  1. Fixing the interface incompatibility in flash_attn_interface.py
  2. Adding unit tests for the mask and bias interface.
  3. Supporting odd sequence lengths in the last dimension of the mask/bias.

robotcator avatar Oct 19 '22 07:10 robotcator

It does not work if the mask or bias has an odd sequence length. CUDA error (/tmp/pip-req-build-k5fpgkes/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu:140): misaligned address

reymondzzzz avatar Nov 03 '22 14:11 reymondzzzz

@guolinke @robotcator Do we need both mask & bias, or would a single bias suffice? I think that could simplify the code & reduce compilation time.

Attention Mask [total_size, head, seq_q, seq_k]

  1. total_size must be the same as q's total_size
  2. head must be 1 or the same as q's head
  3. seq_q must be 1
  4. seq_k must be the same as k's seq_k

From the shape given, my understanding is that the mask is a key-padding mask. Does it change across different layers for the same batch? If the key-padding mask doesn't change across layers, then the most performant way to handle it is to remove padding before the first layer (we have a function unpad_input), run through all the layers, then optionally add back the padding tokens. Is my understanding correct?
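
For reference, a rough sketch of this unpad-once / pad-back pattern using flash_attn's bert_padding helpers (the exact return signature of unpad_input varies across versions):

import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 4, 256, 512
hidden = torch.randn(batch, seqlen, dim, dtype=torch.float16, device='cuda')
key_padding_mask = torch.ones(batch, seqlen, dtype=torch.bool, device='cuda')
key_padding_mask[0, 200:] = False  # the last tokens of sample 0 are padding

# Strip padding once, before the first layer.
hidden_unpad, indices, cu_seqlens, max_seqlen = unpad_input(hidden, key_padding_mask)

# ... run all attention layers on the packed (total_tokens, dim) tensor,
#     passing cu_seqlens / max_seqlen to the flash-attention kernels ...

# Optionally restore the padded layout afterwards.
hidden_padded = pad_input(hidden_unpad, indices, batch, seqlen)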

tridao avatar Nov 06 '22 19:11 tridao

It does not work if the mask or bias has an odd sequence length. CUDA error (/tmp/pip-req-build-k5fpgkes/csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu:140): misaligned address

Thank you for the report. Supporting odd sequence lengths in the last dimension of the mask/bias is currently on our progress list.

robotcator avatar Nov 07 '22 05:11 robotcator

@guolinke @robotcator Do we need both mask & bias, or would a single bias suffice? I think that could simplify the code & reduce compilation time.

Attention Mask [total_size, head, seq_q, seq_k]

  1. total_size must be the same as q's total_size
  2. head must be 1 or the same as q's head
  3. seq_q must be 1
  4. seq_k must be the same as k's seq_k

From the shape given, my understanding is that the mask is a key-padding mask. Does it change across different layers for the same batch? If the key-padding mask doesn't change across layers, then the most performant way to handle it is to remove padding before the first layer (we have a function unpad_input), run through all the layers, then optionally add back the padding tokens. Is my understanding correct?

Thanks for the suggestion @tridao. The flattened non-padding input is not trivial in AlphaFold2:

  1. There are two representations (token-level and pair-level) and four kinds of attention in the Evoformer, each with a different mask/bias.
  2. The two representations communicate at each Evoformer layer, and the padded form is more convenient for that computation.

guolinke avatar Nov 07 '22 08:11 guolinke

The flattened non-padding input is not trivial in AlphaFold2.

I see, thanks for explaining, this is very helpful. How about we pass in a tensor (of type int) with the sequence lengths of the keys for each batch? That might be faster (we read 1 int instead of a vector of mask values) and simpler (reduced code complexity and compilation time). Would this work for the AlphaFold2 use case?

If this sounds reasonable, I'll take a stab at implementing the seqlen_k masking and then rebase and merge the bias part from this PR?
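
For illustration, a hypothetical sketch of what such a per-batch seqlens_k tensor would look like when the masked keys sit only at the tail (plain PyTorch; not an existing API):

import torch

# A tail-only additive key-padding mask ...
key_padding_mask = torch.tensor([[0., 0., 0., float('-inf'), float('-inf')],
                                 [0., 0., 0., 0., float('-inf')]])

# ... carries the same information as one integer per batch element.
seqlens_k = (key_padding_mask == 0).sum(dim=-1).to(torch.int32)  # tensor([3, 4], dtype=torch.int32)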

tridao avatar Nov 07 '22 18:11 tridao

Another way to phrase this question: is the mask for each sequence always of the form [0, 0, ..., 0, -inf, -inf, ...]? Or could it have the form [0, -inf, 0, ..., -inf, 0]? That is, are the masked keys always at the end of the sequence?

tridao avatar Nov 07 '22 20:11 tridao

Another way to phrase this question: is the mask for each sequence always of the form [0, 0, ..., 0, -inf, -inf, ...]? Or could it have the form [0, -inf, 0, ..., -inf, 0]? That is, are the masked keys always at the end of the sequence?

Hi @tridao, sorry for the late reply. Using the key-padding-mask style would indeed be a good way to reduce code complexity and compilation time, but we checked, and the masked keys are not always at the end of the sequence.

One case is the gen_msa_attn_mask function here, which generates two types of masks, i.e. row_mask and col_mask.

The row_mask is generated from the original msa_mask tensor and the col_mask is generated from the transpose of the msa_mask tensor, so the col_mask tensor's masked keys are not at the end of the sequence. A minimal example is as follows.

# the original `msa_mask` tensor.
tensor([[0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf]])

# the transpose of the `msa_mask` tensor.
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [-inf, 0.],
        [-inf, -inf],
        [-inf, -inf]])
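
A hypothetical reconstruction of how these two masks could be derived (gen_msa_attn_mask's actual implementation may differ):

import torch

msa_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                         [1, 1, 1, 1, 0, 0]], dtype=torch.bool)

zero, neg_inf = torch.tensor(0.), torch.tensor(float('-inf'))
row_mask = torch.where(msa_mask, zero, neg_inf)    # -inf only at the tail of each row
col_mask = torch.where(msa_mask.T, zero, neg_inf)  # -inf no longer confined to the tail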

Another case is that the msa_mask's masked keys are not always padded at the end of the sequence; they can appear at any position.

So we chose to use the attention mask rather than the key-padding-mask style. If you have any questions, please feel free to contact us. We also suffer from the compilation time problem and hope we can find a way to tackle it.

robotcator avatar Nov 09 '22 04:11 robotcator

@robotcator I encounter gradient overflow when attn_mask is not None or attn_bias is not None. Could you give me some advice?

logicwong avatar Jan 02 '23 09:01 logicwong

@tridao Any update on merging this, or the part to support arbitrary masks and biases?

rahul003 avatar Jan 04 '23 18:01 rahul003

I just haven't had time to review and merge it (it's a pretty big change). Still trying to figure out a good way to support both mask and bias without increasing compilation time by 4x.

tridao avatar Jan 04 '23 19:01 tridao

@robotcator I encounter gradient overflow when attn_mask is not None or attn_bias is not None. Could you give me some advice?

Do you mean overflow or nan? And can you provide some shapes of inputs?

robotcator avatar Jan 06 '23 14:01 robotcator

@robotcator I encounter gradient overflow when attn_mask is not None or attn_bias is not None. Could you give me some advice?

Do you mean overflow or nan? And can you provide some shapes of inputs?

The model is trained with FP16. With FP16 training the loss may explode, so we progressively lower the dynamic loss scale until it reaches the minimum value. If attn_bias is not None, the loss scale quickly reaches the minimum value at the beginning of training, like this: (see attached loss-scale screenshot)

The code snippet is shown below (following https://github.com/dptech-corp/flash-attention/blob/main/flash_attn/attention.py):

def attention(q, k, v, attn_bias, seq_len):
	# q (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
	# k (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
	# v (bsz * seq_len, num_heads, dim) = (128 * seq_len, 12, 64)
	# attn_bias (bsz, num_heads, seq_len, seq_len) = (128, 12, seq_len, seq_len)
	bsz = attn_bias.shape[0]

	cu_seqlens = torch.arange(
	    0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=q.device
	)
	attn = flash_attn_unpadded_func(
	    q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
	    attn_mask=None, attn_bias=attn_bias,
	    dropout_p=0.0,
	    softmax_scale=1.0, causal=False
	)
	return attn

logicwong avatar Jan 07 '23 03:01 logicwong

@robotcator I encounter gradient overflow when attn_mask is not None or attn_bias is not None. Could you give me some advice?

Do you mean overflow or nan? And can you provide some shapes of inputs?

The model is trained with FP16. With FP16 training the loss may explode, so we progressively lower the dynamic loss scale until it reaches the minimum value. If attn_bias is not None, the loss scale quickly reaches the minimum value at the beginning of training, like this: (see attached loss-scale screenshot)

It seems not trivial to figure out. Here are some ideas from my side: 1) check whether the half precision overflows due to its limited representation range; 2) the attention bias & mask handling is not as flexible as the PyTorch version. The broadcast mechanism is very flexible in PyTorch, but it takes more effort to implement when all the operations are fused into one kernel. We implemented a limited set of shapes to fit our model, so it does not generalize to all models; please check the supported shapes list carefully.
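
As a rough debugging sketch for point 1), in plain PyTorch (not this fork's API; q/k/v here use a (bsz, heads, seq, dim) layout for clarity), one can compare the fused output against an fp32 reference and check that the bias stays within fp16 range:

import torch

def reference_attention(q, k, v, attn_bias, softmax_scale=1.0):
    # fp32 reference attention with an explicit additive bias.
    scores = torch.einsum('bhqd,bhkd->bhqk', q.float(), k.float()) * softmax_scale
    probs = torch.softmax(scores + attn_bias.float(), dim=-1)
    return torch.einsum('bhqk,bhkd->bhqd', probs, v.float())

# Sanity checks before suspecting the kernel:
# 1) is the bias itself representable in fp16?
#    print(attn_bias.abs().max().item(), torch.finfo(torch.float16).max)
# 2) does the fused output match the fp32 reference?
#    print((out_flash.float() - reference_attention(q, k, v, attn_bias)).abs().max())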

robotcator avatar Jan 09 '23 05:01 robotcator

Thank you for the reply.

logicwong avatar Jan 10 '23 09:01 logicwong

Hi, thanks everyone for bringing up this enhancement! Is this PR a way to support custom attention masks? Is this the best workaround so far, given that it is not officially supported yet?

subercui avatar Apr 09 '23 19:04 subercui

Hi, thanks everyone for bringing up this enhancement! Is this PR a way to support custom attention masks? Is this the best workaround so far, given that it is not officially supported yet?

For the padding mask, I think the official repo already supports it. For custom attention masks, we also support some shapes, but not all.

robotcator avatar May 11 '23 06:05 robotcator

@robotcator I have a question about attn_bias: if my attn_bias is trainable, will flash attention compute the gradient of attn_bias automatically?

Birdylx avatar Aug 15 '23 17:08 Birdylx

Hello, what's the status of this PR? Is the code in a usable state? I didn't quite get it from the above discussion. Thanks for your work, awesome job!

nikita-petrashen avatar Feb 08 '24 12:02 nikita-petrashen

@robotcator I have a question about attn_bias: if my attn_bias is trainable, will flash attention compute the gradient of attn_bias automatically?

I don't know whether it's too late to reply; actually, yes, attn_bias's grad will be computed automatically.
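
As a minimal sketch (using the fork's flash_attn_unpadded_func signature quoted earlier in this thread; q/k/v, cu_seqlens and seq_len are assumed to be defined):

import torch

# A leaf attn_bias with requires_grad=True receives its gradient on backward.
attn_bias = torch.zeros(1, 8, 128, 128, dtype=torch.float16, device='cuda', requires_grad=True)

# out = flash_attn_unpadded_func(q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
#                                attn_mask=None, attn_bias=attn_bias,
#                                dropout_p=0.0, softmax_scale=1.0, causal=False)
# out.sum().backward()
# print(attn_bias.grad.shape)  # populated automatically by autograd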

robotcator avatar Mar 06 '24 07:03 robotcator

Guys, let's face it. It's like there is a hidden force not allowing this one to go through. Someone is gatekeeping

nofreewill42 avatar Mar 12 '24 12:03 nofreewill42

For anyone still looking for this, see: https://pytorch.org/blog/flexattention/
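
For example, an additive attention bias can be expressed there as a score_mod (a minimal sketch; requires a recent PyTorch, roughly 2.5+, and API details may differ across versions):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 8, 256, 64
q = torch.randn(B, H, S, D, device='cuda', dtype=torch.float16)
k = torch.randn(B, H, S, D, device='cuda', dtype=torch.float16)
v = torch.randn(B, H, S, D, device='cuda', dtype=torch.float16)
attn_bias = torch.randn(B, H, S, S, device='cuda', dtype=torch.float16)

def add_bias(score, b, h, q_idx, kv_idx):
    # Additive attention bias expressed as a score modification.
    return score + attn_bias[b, h, q_idx, kv_idx]

out = flex_attention(q, k, v, score_mod=add_bias)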

maxall41 avatar Sep 14 '24 03:09 maxall41