flash-attention
Most effective way of applying key_padding_mask
Hi, I am trying to integrate flash-attention into the model I am working on.
My model uses a key_padding_mask
to support variable-length samples in a batch during finetuning.
I found that flash-attention provides a flash_attn_varlen_kvpacked_func
implementation. Since it requires some additional metadata (cumulative sequence lengths and maximum sequence lengths), I implemented it as below.
```python
import torch
from einops import rearrange
from flash_attn import flash_attn_varlen_kvpacked_func
from flash_attn.bert_padding import unpad_input

# Queries have no padding, so just flatten the batch dimension.
q_unpad = rearrange(q, "b s h d -> (b s) h d")
cu_seqlens_q = torch.arange(0, (B + 1) * M, step=M, dtype=torch.int32, device=device)
max_seqlen_q = M
output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=B)

# Remove padded key/value tokens according to key_padding_mask.
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
# Pack K and V: flash_attn_varlen_kvpacked_func expects shape (total_k, 2, h, d),
# so stack along dim=1 (not dim=2).
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)

out_unpad = flash_attn_varlen_kvpacked_func(
    q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
    causal=False, return_attn_probs=False,
)
out = output_pad_fn(out_unpad)
```
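For reference, the cu_seqlens tensors are just cumulative sequence lengths over the unpadded batch. A minimal plain-Python sketch (my own illustration, not flash-attention code) of what unpad_input derives from a key padding mask, assuming True marks a valid token:

```python
# Hypothetical 2-sample batch, max length 4; True = valid token.
key_padding_mask = [
    [True, True, True, False],   # sample 0: length 3
    [True, True, False, False],  # sample 1: length 2
]

# Per-sample valid lengths.
seqlens = [sum(row) for row in key_padding_mask]   # [3, 2]

# Cumulative sequence lengths: boundaries of each sample in the
# flattened (total_k, h, d) tensor that unpad_input produces.
cu_seqlens_k = [0]
for n in seqlens:
    cu_seqlens_k.append(cu_seqlens_k[-1] + n)      # [0, 3, 5]

max_seqlen_k = max(seqlens)                        # 3

print(cu_seqlens_k, max_seqlen_k)
```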
However, this implementation is much slower than simply calling F.scaled_dot_product_attention(q, k, v, attn_mask=key_padding_mask.unsqueeze(1).unsqueeze(1), dropout_p=dropout_p).
Am I using flash-attention correctly with a key padding mask?
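For context, the SDPA baseline above broadcasts the (B, S) boolean mask to (B, 1, 1, S) so every query attends only to valid keys. A minimal NumPy sketch of that masked softmax (my own illustration of the masking semantics, assuming True = attend):

```python
import numpy as np

def masked_softmax(scores, key_padding_mask):
    # scores: (B, H, M, S); key_padding_mask: (B, S), True = valid key.
    mask = key_padding_mask[:, None, None, :]        # broadcast to (B, 1, 1, S)
    scores = np.where(mask, scores, -np.inf)         # padded keys get -inf
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)                               # exp(-inf) -> 0 weight
    return w / w.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.standard_normal((2, 1, 4, 4))
mask = np.array([[True, True, True, False],
                 [True, True, False, False]])
w = masked_softmax(scores, mask)
print(w[0, 0, 0])  # last entry is 0: the padded key receives no weight
```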