
Avoid padding computation with `cu_seqlens`

Open · imoneoi opened this issue on Sep 14, 2024 · 3 comments

To work with torch.compile, which is more efficient with static shapes, I pad some tokens at the end so that q, k, v have a static shape, e.g. [N, D].

Can I set the last element of cu_seqlens in the varlen API to be less than N so that the padding is not computed? Also, is the backward pass still accurate in this case?
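
Concretely, I mean something like this (a minimal sketch; `pad_to`, `total_tokens`, and the example sequence lengths below are made-up numbers just to illustrate dropping the padded tail from `cu_seqlens`):

```python
import torch
from flash_attn import flash_attn_varlen_func

# Hypothetical static-padded setup: 1000 real tokens, padded to a static 1024.
num_heads, head_size = 8, 64
total_tokens, pad_to = 1000, 1024

q = torch.randn(pad_to, num_heads, head_size, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Two real sequences of length 500; the last cu_seqlens entry is total_tokens,
# which is less than pad_to, so the trailing padding rows are never attended to.
cu_seqlens = torch.tensor([0, 500, total_tokens], dtype=torch.int32, device="cuda")
max_seqlen = 500

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
)
# out has shape [pad_to, num_heads, head_size]; rows past total_tokens
# are not written by the kernel.
```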

imoneoi avatar Sep 14 '24 14:09 imoneoi

Yes, I think that should work. You should still test it, though.

tridao avatar Sep 14 '24 17:09 tridao

Thanks! I have tested the kernel and it does work. However, the padding elements of the output may be uninitialized, resulting in NaN/inf in the forward and backward passes. Can we include a fix that simply zeros these elements?
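
Until there is a fix in the library, the workaround I have in mind looks roughly like this (a sketch using the same made-up setup as above; `n_real` and `pad_mask` are my own names). Since the backward pass uses `out`, I avoid an in-place write when zeroing:

```python
import torch
from flash_attn import flash_attn_varlen_func

# Hypothetical static-padded setup, as in the earlier sketch.
num_heads, head_size, pad_to = 8, 64, 1024
q = torch.randn(pad_to, num_heads, head_size, device="cuda",
                dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
cu_seqlens = torch.tensor([0, 500, 1000], dtype=torch.int32, device="cuda")
max_seqlen = 500
n_real = int(cu_seqlens[-1])  # real tokens end here, before pad_to

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
)
# Zero the (possibly uninitialized) padding rows out-of-place, so the `out`
# tensor saved for the backward pass is not mutated.
pad_mask = torch.arange(pad_to, device=out.device).view(-1, 1, 1) >= n_real
out = out.masked_fill(pad_mask, 0)
```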

imoneoi avatar Sep 15 '24 01:09 imoneoi

BTW, here is the code I used for testing:

```python
from typing import Any
import torch

from tqdm import tqdm
from flash_attn import flash_attn_varlen_func


def test_flash_attn_padding(
    seed: int = 0,
    test_rounds: int = 10,
    num_heads: int = 8,
    head_size: int = 64,
    seq_len: int = 160,
    batch_size: int = 131_072,
    dtype: Any = torch.bfloat16
):
    torch.manual_seed(seed)
    torch.set_default_device("cuda")

    # Construct test data packed as [total tokens, num_heads, head_size]
    q = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    k = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
    v = torch.randn((batch_size, num_heads, head_size), dtype=dtype)

    # Split the packed batch into batch_size // seq_len full sequences plus one
    # shorter tail (assumes batch_size % seq_len != 0)
    seqlens = torch.cat([
        torch.full((batch_size // seq_len, ), seq_len, dtype=torch.int32),
        torch.full((1, ), batch_size % seq_len, dtype=torch.int32)
    ])

    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(-1, dtype=seqlens.dtype), (1, 0))
    max_seqlen = int(seqlens.max())  # max_seqlen_q/k are documented as Python ints

    # Run multiple rounds so that buffers allocated with torch.empty() are more
    # likely to contain stale, non-zero values
    for _ in tqdm(range(test_rounds)):
        # Fwd
        gt_out    = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens,      cu_seqlens_k=cu_seqlens,      max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)

        assert torch.allclose(nopad_out[:cu_seqlens[-2]], gt_out[:cu_seqlens[-2]])
        # assert torch.allclose(nopad_out[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # May fail: padding rows of the output are left uninitialized

        # Bwd
        # ground truth
        dgrad = torch.randn((batch_size, num_heads, head_size), dtype=dtype)
        q.requires_grad_()
        k.requires_grad_()
        v.requires_grad_()

        gt_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens,      cu_seqlens_k=cu_seqlens,      max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen)
        (dgrad * gt_out).sum().backward()

        gt_dq = q.grad
        gt_dk = k.grad
        gt_dv = v.grad
        q.grad = None
        k.grad = None
        v.grad = None

        # unpadded
        nopad_out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu_seqlens[:-1], cu_seqlens_k=cu_seqlens[:-1], max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen) 
        (dgrad * nopad_out).sum().backward()

        assert torch.allclose(q.grad[:cu_seqlens[-2]], gt_dq[:cu_seqlens[-2]])
        # assert torch.allclose(q.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # May fail: padding rows of dq are left uninitialized

        assert torch.allclose(k.grad[:cu_seqlens[-2]], gt_dk[:cu_seqlens[-2]])
        # assert torch.allclose(k.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # May fail: padding rows of dk are left uninitialized

        assert torch.allclose(v.grad[:cu_seqlens[-2]], gt_dv[:cu_seqlens[-2]])
        # assert torch.allclose(v.grad[cu_seqlens[-2]:], torch.zeros((), dtype=dtype))  # May fail: padding rows of dv are left uninitialized

        q.grad = None
        k.grad = None
        v.grad = None

if __name__ == "__main__":
    test_flash_attn_padding()
```

imoneoi avatar Sep 15 '24 02:09 imoneoi