
Most effective way of applying key_padding_mask

Open · YJYJLee opened this issue 9 months ago · 1 comment

Hi, I am trying to integrate flash-attention into the model I am working on. My model uses key_padding_mask to support variable-length samples in a batch during finetuning. I found that flash-attention provides flash_attn_varlen_kvpacked_func, and that some additional metadata (the cu_seqlens and max_seqlen arguments) needs to be passed to it, so I implemented it as below.

    import torch
    from einops import rearrange
    from flash_attn import flash_attn_varlen_kvpacked_func
    from flash_attn.bert_padding import unpad_input

    # q: (B, M, H, D); k, v: (B, S_k, H, D); key_padding_mask: (B, S_k) bool, True = real token
    # Queries are all full length M, so their cumulative sequence lengths are just multiples of M.
    q_unpad = rearrange(q, "b s h d -> (b s) h d")
    cu_seqlens_q = torch.arange(0, (B + 1) * M, step=M, dtype=torch.int32, device=device)
    max_seqlen_q = M
    output_pad_fn = lambda output_unpad: rearrange(output_unpad, "(b s) h d -> b s h d", b=B)
    # Drop padded key/value positions and get the cu_seqlens / max_seqlen metadata for K and V.
    k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
    v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
    # Pack K and V into the (total_k, 2, H, D) layout that the kvpacked kernel expects.
    kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
    out_unpad = flash_attn_varlen_kvpacked_func(
        q_unpad, kv_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
        return_attn_probs=False, causal=False,
    )
    out = output_pad_fn(out_unpad)

However, this implementation is much slower than simply using F.scaled_dot_product_attention(q, k, v, attn_mask=key_padding_mask.unsqueeze(1).unsqueeze(1), dropout_p=dropout_p) (a sketch of that baseline is included below for reference).
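Here is a minimal, self-contained sketch of that eager baseline. The concrete shapes, dtype, and the transposes into the (B, H, S, D) layout expected by F.scaled_dot_product_attention are illustrative assumptions, not taken from the original model:

    # Hypothetical sizes for illustration only.
    import torch
    import torch.nn.functional as F

    B, M, H, D = 2, 1024, 8, 64
    device = "cuda"
    dtype = torch.float16
    q = torch.randn(B, M, H, D, device=device, dtype=dtype)
    k = torch.randn(B, M, H, D, device=device, dtype=dtype)
    v = torch.randn(B, M, H, D, device=device, dtype=dtype)
    # Boolean mask, True = real token, False = padding (same convention as unpad_input).
    key_padding_mask = torch.ones(B, M, dtype=torch.bool, device=device)
    key_padding_mask[1, M // 2:] = False  # pretend the second sample is half padding

    # SDPA expects (B, H, S, D), so transpose from the (B, S, H, D) layout used above.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=key_padding_mask[:, None, None, :],  # (B, 1, 1, M), broadcast over heads and queries
        dropout_p=0.0,
    ).transpose(1, 2)  # back to (B, M, H, D)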

Am I using flash-attention correctly with key_padding_mask?

YJYJLee · Sep 26, 2023