
FA3 varlen_bwd hangs (FA2 works in the same case)

Open · goldhuang opened this issue 1 year ago · 15 comments

from einops import rearrange
import torch
import flashattn_hopper_cuda

def get_cu_seqlens(seqlens_in_batch):
    if isinstance(seqlens_in_batch, list):
        seqlens_in_batch = torch.tensor(seqlens_in_batch)
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        cu_seqlens,
        max_seqlen_in_batch,
    )

seqlens = 1023
d_head=128
d_model=4096
dtype = torch.bfloat16
q_mat = torch.randn(d_model, d_model).cuda().to(dtype=dtype)
k_mat = torch.randn(d_model, d_model).cuda().to(dtype=dtype)
v_mat = torch.randn(d_model, d_model).cuda().to(dtype=dtype)

x = torch.randn(seqlens*2-1, d_model).cuda().to(dtype=dtype).requires_grad_(True)
# x = 2 * (x - x.min()) / (x.max() - x.min()) - 1
q = rearrange(x @ q_mat, "... (nh dh) -> ... nh dh", dh=d_head).to(dtype=dtype)
k = rearrange(x @ k_mat, "... (nh dh) -> ... nh dh", dh=d_head).to(dtype=dtype)
v = rearrange(x @ v_mat, "... (nh dh) -> ... nh dh", dh=d_head).to(dtype=dtype)

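cu_seqlens_q, max_seqlen_q = get_cu_seqlens([0, 0, seqlens])  # [0, 0, 0, seqlens]
# [0, 0, 1, 2*seqlens]: this totals 2*seqlens tokens, one more than the 2*seqlens-1 rows in x
cu_seqlens_kv, max_seqlen_kv = get_cu_seqlens([0, 1, 2*seqlens-1])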
cu_seqlens_q = cu_seqlens_q.to(q.device)
cu_seqlens_kv = cu_seqlens_kv.to(q.device)

class FA3_Varlen(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k):
        q_var = q[0]
        k_var = k[0]
        v_var = v[0]
        ctx.softmax_scale = q.shape[-1] ** (-0.5)
        ctx.max_seqlen_q = max_seqlen_q
        ctx.max_seqlen_k = max_seqlen_k
        out_var, _, _, _, _, lse_var = flashattn_hopper_cuda.varlen_fwd(q_var, k_var, v_var, None, cu_seqlens_q, cu_seqlens_k, None, max_seqlen_q, max_seqlen_k, ctx.softmax_scale, False)
        ctx.save_for_backward(out_var, lse_var, q_var, k_var, v_var, cu_seqlens_q, cu_seqlens_k)
        return out_var.unsqueeze(0)

    @staticmethod
    def backward(ctx, grad_output):
        # Use the saved input from the forward pass
        out, lse, q, k, v, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        grad_output_squeezed = torch.squeeze(grad_output, dim=0)
        dq, dk, dv, *rest = flashattn_hopper_cuda.varlen_bwd(grad_output_squeezed, q, k, v, out, lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.softmax_scale, False, False)
        dq = dq.unsqueeze(0)
        dk = dk.unsqueeze(0)
        dv = dv.unsqueeze(0)
        return dq, dk, dv, None, None, None, None
    
y = FA3_Varlen.apply(q[None], k[None], v[None], cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)[0]
print ("FWD done")
y.sum().backward()
print ("BWD done")

FA3 varlen bwd hangs forever in this case, while FA2 varlen bwd works. I debugged the CUDA kernel and found that it falls into an infinite loop in the producer-consumer logic.
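Because a kernel stuck in a producer-consumer wait generally cannot be interrupted from inside the same Python process, one hedged way to turn a hang like this into a test failure is to run the repro in a child interpreter with a timeout. This is a minimal sketch; the helper name `runs_within` and the timeout values are illustrative assumptions, not part of flash-attention.

```python
import subprocess
import sys


def runs_within(script: str, timeout_s: float) -> bool:
    """Run `script` in a fresh Python interpreter.

    Returns True if it exits successfully before `timeout_s` seconds and False
    if it times out (subprocess.run kills the child when the timeout expires).
    A non-zero exit raises subprocess.CalledProcessError, so a crash is not
    mistaken for a hang.
    """
    try:
        subprocess.run([sys.executable, "-c", script], check=True, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        return False
    return True


# A script that sleeps past the deadline is reported as a hang.
print(runs_within("import time; time.sleep(5)", timeout_s=1.0))  # False
```

With the repro saved to a file, passing `[sys.executable, path]` to `subprocess.run` works the same way.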

goldhuang, Oct 03 '24

Which CUDA version are you using?

tridao, Oct 03 '24

@tridao torch: 2.4.0+cu124, nvcc: V12.4.131

goldhuang, Oct 03 '24

the varlen hang is happening to me too on flashattn-hopper==3.0.0b1 (the wheel distributed alongside flash_attn==2.7.2.post1).

using CUDA 12.8, H100, pytorch 2.6.0, driver 535.216.01.

if seqused is None, backwards pass hangs.
if seqused is non-None and all tokens are used, backwards pass hangs.
if seqused is non-None and has sparsity, backwards pass does not hang but I get NaN gradients.
forward pass gets correct results.

if I invoke FA2 (flash_attn==v2.7.4.post1)'s varlen APIs with the same cu_seqlens: forward pass gets correct results, and backwards pass doesn't hang and grads are finite.

problem reproduces without using GQA (I did have one GQA varlen in my model, but I demoted that one to FA2, and the hang still occurs. the other non-GQA varlen functions are likely culprits).

head dim is 128.

Birch-san, Feb 17 '25

> @tridao torch: 2.4.0+cu124, nvcc: V12.4.131

You should try the latest version. It works fine for me. Btw, your sequence lengths aren't right: x has seqlens*2-1 rows, but your cu_seqlens_kv is calculated to be [0, 0, 1, 2*seqlens], which requires k & v to have length seqlens*2.
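A mismatch like this can be caught on the host before the kernel is launched. The sketch below assumes the packed (total_tokens, n_heads, head_dim) layout used in the repro; `check_cu_seqlens` is a hypothetical helper, not a flash-attention function. It only rejects offsets that run past the end of the tensor, since unused trailing rows (as with cu_seqlens_q in the repro) are not necessarily an error.

```python
import torch
import torch.nn.functional as F


def check_cu_seqlens(cu_seqlens: torch.Tensor, num_rows: int) -> None:
    """Raise ValueError if cu_seqlens cannot index a packed tensor with num_rows rows."""
    cu = cu_seqlens.tolist()
    if cu[0] != 0:
        raise ValueError(f"cu_seqlens must start at 0, got {cu}")
    if any(later < earlier for earlier, later in zip(cu, cu[1:])):
        raise ValueError(f"cu_seqlens must be non-decreasing, got {cu}")
    if cu[-1] > num_rows:
        raise ValueError(
            f"cu_seqlens ends at {cu[-1]} but the packed tensor only has {num_rows} rows"
        )


seqlens = 1023
num_rows = 2 * seqlens - 1  # rows in x, and therefore in k and v
lengths = torch.tensor([0, 1, 2 * seqlens - 1])
cu_seqlens_kv = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens_kv.tolist())  # [0, 0, 1, 2046]

try:
    check_cu_seqlens(cu_seqlens_kv, num_rows)
except ValueError as err:
    print(err)  # cu_seqlens ends at 2046 but the packed tensor only has 2045 rows
```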

tridao, Feb 17 '25


You should try the latest version, which is a big rewrite of FA3 from the 2.7.2 tag (the big rewrite happened between 2.7.2 and 2.7.3; sorry, the tags are confusing because they refer to the version of FA2, not FA3). The latest version of FA3 is much better tested and should be a lot more robust.

tridao, Feb 17 '25

roger that.

we originally went with 2.7.2.post1 because under 2.7.4.post1 we couldn't find a way to install FA2 and FA3 simultaneously.
the older convention was that FA2 and FA3 were packaged with different names (flash_attn and flashattn-hopper), and both could be installed together.

what's the new convention? is the idea that we would now just install one wheel, flash_attn? does it provide access to both FA2 and FA3? or is it just FA3 now?

does 2.7.4.post1 provide access to both new and old API conventions? currently I use both APIs (but only because I have a use_fa3 boolean for switching backends, whilst I parity-test the implementations).

# 2.7.2.post1 convention:

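# new APIs (FA3)
from flashattn_hopper.flash_attn_interface import (
    flash_attn_varlen_func,
    flash_attn_func,
)

# old APIs (FA2); note that flash_attn_func and flash_attn_varlen_func below shadow
# the FA3 names above, so one set needs aliasing ("as ...") if both are used in one module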
from flash_attn.flash_attn_interface import (
    flash_attn_func,
    flash_attn_kvpacked_func,
    flash_attn_qkvpacked_func,
    flash_attn_varlen_func,
    flash_attn_varlen_kvpacked_func,
    flash_attn_varlen_qkvpacked_func,
)
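For a `use_fa3` switch like the one described above, one hedged way to avoid importing two functions with the same name is to import lazily by module path and normalize the return value. This sketch assumes the 2.7.2.post1 module paths from the snippet above (later releases may use different paths); `varlen_attention` and `VARLEN_MODULES` are hypothetical names.

```python
import importlib
from typing import Any

# Module paths follow the 2.7.2.post1 packaging shown above; later releases may differ.
VARLEN_MODULES = {
    True: "flashattn_hopper.flash_attn_interface",  # FA3
    False: "flash_attn.flash_attn_interface",  # FA2
}


def varlen_attention(use_fa3: bool, **kwargs: Any):
    """Call FA3 or FA2 flash_attn_varlen_func and return only the attention output."""
    module = importlib.import_module(VARLEN_MODULES[use_fa3])
    result = module.flash_attn_varlen_func(**kwargs)
    # The FA3 calls in this thread unpack (out, softmax_lse); FA2 returns just `out`
    # unless return_attn_probs=True, in which case `out` is the first element.
    return result[0] if isinstance(result, tuple) else result
```

Keyword arguments are passed through unchanged, so FA3-only arguments such as `seqused_q` must only be supplied when `use_fa3` is true.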

Birch-san, Feb 17 '25

I see, I've just pushed a commit to avoid the name collision. At some point FA3 will replace FA2, but for now they should be able to coexist.

tridao, Feb 18 '25

ah that's fantastic, thanks; we'll try that. just waiting for the image to finish building.

Birch-san, Feb 19 '25

okay, we have a v2.7.4.post1-ish (b36ad4ef767d2d5536ff8af2e3f720ae4eba731c) image now.

if seqused is None, backwards pass now gives me finite values without hanging.
if seqused is non-None and all tokens are used, backwards pass gives me non-finite gradients.
if seqused is non-None and has sparsity, backwards pass gives me non-finite gradients.

forward pass gets correct results.

err, actually I discovered a (perhaps unrelated) problem with the forward pass, on both v2.7.2.post1 and v2.7.4.post1, that only occurs when FA3 varlen APIs are used inside torch.compile.
the forward pass gives me non-finite values if I repeat an identical workload enough times. this only happened when I used a non-None seqused. this was on a self-attention, so perhaps the significance was that it used seqlens_q. I tried to make a minimal repro but it didn't reproduce with the same shapes in a smaller example.
the problem also happened when using FA2 operations inside a compiled model, but since seqused does not exist in FA2, the problem occurred despite not using seqused.

Birch-san, Mar 11 '25

I should clarify in the docs that seqused is only for the cases when you know what you're doing (e.g. cache length during decoding, or splitting when doing context parallel). At "unused" locations the memory is uninitialized so they could be Inf/NaN.
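If downstream code might read those rows, a minimal user-side workaround is to overwrite them after the call. This sketch assumes the packed (total_tokens, n_heads, head_dim) layout from the repros in this thread; `zero_unused_rows` is a hypothetical helper, not a flash-attention API.

```python
import torch


def zero_unused_rows(
    out: torch.Tensor, cu_seqlens: torch.Tensor, seqused: torch.Tensor
) -> torch.Tensor:
    """Zero the rows of a packed varlen tensor that lie beyond each sequence's seqused.

    out: (total_tokens, ...); cu_seqlens: (batch + 1,); seqused: (batch,).
    """
    keep = torch.zeros(out.shape[0], dtype=torch.bool, device=out.device)
    for start, end, used in zip(
        cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist(), seqused.tolist()
    ):
        keep[start : start + min(used, end - start)] = True
    keep = keep.view(-1, *([1] * (out.dim() - 1)))  # broadcast over heads / head_dim
    # torch.where rather than multiplication, because NaN * 0 is still NaN.
    return torch.where(keep, out, torch.zeros_like(out))
```

The Python loop copies the offsets to the host, so this is better suited to debugging than to a hot path or a compiled region.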

tridao, Mar 11 '25

yes, we use seqused in image-generation to avoid attending to pad tokens in the text condition.
I'll describe the NaNs I'm getting on forward pass in more detail, to hopefully evidence that we're using seqused for a valid use-case, and that it's working to an extent, but fails under compilation.
this is (probably) a different issue to the NaN gradients exhibited on the backward pass, but perhaps it's a useful clue.

in image-text cross-attention there are no problems. we can safely use seqused here for the forward pass.
in image-text multimodal self-attention, there's intermittent problems in the forward pass iff we enable torch.compile.

by "image-text multimodal self-attention", I mean like in Flux:

  • image and text are concatenated into a 1D sequence
  • the sequence self-attends
  • some tokens are padding because they're unused text tokens
    • since this is self-attention, that includes query tokens. so we use seqused_q and seqused_k.
  • the 1D sequence is split back into image and text again
  • an image residual and a text residual continue through the model
  • eventually the text residual is discarded, but the image residual is used for outputting the final image
  • ultimately, nowhere in the architecture attends to pad tokens, and no padding is returned from the model because text is not returned by the model. it should be fine for these positions to contain uninitialized memory.
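For parity-testing a layout like this, a plain-PyTorch reference can express the same intent: keys at positions >= seqused are masked out, and query rows at positions >= seqused are zeroed rather than left uninitialized. This is a sketch for a single packed sequence, assuming (seq, n_heads, head_dim) tensors and the default 1/sqrt(head_dim) scale; `reference_seqused_attention` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def reference_seqused_attention(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seqused: int
) -> torch.Tensor:
    """Non-causal attention over one packed sequence of shape (seq, n_heads, head_dim).

    Only the first `seqused` positions take part; output rows past `seqused` are zero.
    """
    seq = q.shape[0]
    # scaled_dot_product_attention expects (..., seq, head_dim), so move heads first.
    qh, kh, vh = (t.transpose(0, 1) for t in (q, k, v))
    valid = torch.arange(seq, device=q.device) < seqused  # (seq,)
    # Boolean mask: True means "may attend". Every query sees exactly the valid keys,
    # so no mask row is all-False (which would produce NaN).
    key_mask = valid[None, :].expand(seq, seq)
    out = F.scaled_dot_product_attention(qh, kh, vh, attn_mask=key_mask)
    out = out.transpose(0, 1)  # back to (seq, n_heads, head_dim)
    return torch.where(valid[:, None, None], out, torch.zeros_like(out))
```

Comparing FA3's output against this reference only on the rows below `seqused` checks the kernel without depending on what it leaves in the unused rows.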

I'll show the intermittent problems we get, inferencing the same workload repeatedly.

with seqused and with torch compilation
we get black images (all values NaN) sometimes, with occasional recovery. which images in the sequence end up black seems to be deterministic across runs of the program:
[image not preserved]

looking into the intermediate results during the creation of image 007, we see that it survives 6 forward passes before the NaNs set in:
[image not preserved]

with seqused and without torch compilation
works fine:
[image not preserved]

without seqused, with torch compilation
works fine:
[image not preserved]

Birch-san, Mar 11 '25

Do you have an example of what seqused_q and seqused_k look like? And what are the shapes of q, k, v? I might add an option to zero out the output before the attn fwd computation, or zero out the gradient before the attn bwd computation, so that there are no uninitialized values (at the cost of a small slowdown).
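Until such an option exists, a hedged user-side analogue for the backward pass is a tensor hook that replaces the gradient at unused rows with zeros, since positions excluded by seqused should receive no gradient from the attention anyway. This is a minimal sketch; `zero_grad_rows` is a hypothetical helper, and it assumes a boolean row mask built from cu_seqlens/seqused on the packed layout.

```python
import torch


def zero_grad_rows(t: torch.Tensor, keep_rows: torch.Tensor) -> torch.Tensor:
    """Register a hook on `t` so rows where keep_rows is False get a zero gradient.

    t: (total_tokens, ...) tensor that requires grad, e.g. packed q, k or v.
    keep_rows: (total_tokens,) boolean mask of used positions.
    """
    mask = keep_rows.view(-1, *([1] * (t.dim() - 1)))

    def hook(grad: torch.Tensor) -> torch.Tensor:
        # torch.where rather than grad * mask, so NaN/Inf in unused rows is discarded.
        return torch.where(mask, grad, torch.zeros_like(grad))

    t.register_hook(hook)
    return t


# Tiny CPU demo: rows 2 and 3 would otherwise receive NaN gradients.
x = torch.ones(4, 2, requires_grad=True)
poison = torch.tensor([[1.0], [1.0], [float("nan")], [float("nan")]])
zero_grad_rows(x, torch.tensor([True, True, False, False]))
(x * poison).sum().backward()
print(x.grad)  # rows 0-1 are ones, rows 2-3 are zeros
```

The hook sees the total gradient of `t`, so it should be registered on a tensor whose only consumer is the attention call (for example the packed q/k/v right before it); otherwise legitimate gradient from other uses at those rows would be discarded too.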

tridao, Mar 11 '25

I tried to make a reproducer of the attentions which trigger the black images problem in the forward pass, but alas this script does not trigger the problem.
the shapes/lengths/dimensions do match the model which encounters the problem though, so maybe it will be informative. they're gonna be very weird shapes because the image sequence can be varied resolutions / aspect ratios and undergoes downsampling, then gets concatenated to a text sequence.

the model which experiences NaN on the backward pass is a different model and I haven't isolated which operation is sensitive. the shapes / lengths it deals with are different to those in this reproducer.

import torch
from torch.nn import Module, Linear
from torch import FloatTensor, IntTensor, inference_mode
from tqdm import trange
from einops import rearrange
from typing import Optional, NamedTuple

from flashattn_hopper.flash_attn_interface import flash_attn_varlen_func


class SelfAttention(Module):
    def __init__(self):
        super().__init__()

    def forward(
        self,
        q: FloatTensor,
        k: FloatTensor,
        v: FloatTensor,
        cu_seqlens: IntTensor,
        max_seqlen: int,
        seqused: IntTensor,
    ) -> FloatTensor:
        out, _ = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            seqused_q=seqused,
            seqused_k=seqused,
        )
        return out


class MMOut(NamedTuple):
    a: FloatTensor
    b: FloatTensor


class MMAttention(Module):
    def __init__(
        self,
        a_in_dim: int,
        b_in_dim: int,
        heads: int,
        head_dim: int,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads = heads
        hidden_dim = heads * head_dim
        self.a_qkv_proj = Linear(
            in_features=a_in_dim,
            out_features=hidden_dim * 3,
            bias=False,
            **factory_kwargs,
        )
        self.b_qkv_proj = Linear(
            in_features=b_in_dim,
            out_features=hidden_dim * 3,
            bias=False,
            **factory_kwargs,
        )
        self.a_out_proj = Linear(
            in_features=hidden_dim,
            out_features=a_in_dim,
            bias=False,
            **factory_kwargs,
        )
        self.b_out_proj = Linear(
            in_features=hidden_dim,
            out_features=b_in_dim,
            bias=False,
            **factory_kwargs,
        )

    def forward(
        self,
        a: FloatTensor,
        b: FloatTensor,
        b_seqused: IntTensor,
    ) -> MMOut:
        h, w = a.shape[-3:-1]
        a = rearrange(a, "... h w c -> ... (h w) c")

        a_qkv: FloatTensor = self.a_qkv_proj(a)
        a_q, a_k, a_v = rearrange(
            a_qkv,
            "... (proj n_heads head_dim) -> ... proj n_heads head_dim",
            proj=3,
            n_heads=self.heads,
        ).unbind(-3)

        b_qkv: FloatTensor = self.b_qkv_proj(b)
        b_q, b_k, b_v = rearrange(
            b_qkv,
            "... (proj n_heads head_dim) -> ... proj n_heads head_dim",
            proj=3,
            n_heads=self.heads,
        ).unbind(-3)

        q = torch.cat([a_q, b_q], dim=-3)
        k = torch.cat([a_k, b_k], dim=-3)
        v = torch.cat([a_v, b_v], dim=-3)

        seqused: IntTensor = b_seqused + a.size(-2)
        max_seqlen: int = q.size(-3)
        cu_seqlens: IntTensor = torch.tensor(
            [0, max_seqlen], device=q.device, dtype=torch.int32
        )

        *batch_dims, _, _ = q.shape
        q, k, v = (
            rearrange(proj, "... n_heads head_dim -> (...) n_heads head_dim")
            for proj in (q, k, v)
        )

        out: FloatTensor
        out, _ = flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            seqused_q=seqused,
            seqused_k=seqused,
        )
        out = out.unflatten(0, batch_dims)

        out = rearrange(out, "... n_heads head_dim -> ... (n_heads head_dim)")

        out_a, out_b = out.tensor_split((a_q.size(-3),), dim=-2)

        a_o: FloatTensor = self.a_out_proj(out_a)
        b_o: FloatTensor = self.b_out_proj(out_b)

        a_o = rearrange(a_o, "... (h w) c -> ... h w c", h=h, w=w)

        return MMOut(a_o, b_o)


class Model(Module):
    def __init__(
        self,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.mm0 = MMAttention(
            a_in_dim=1024,
            b_in_dim=2048,
            heads=8,
            head_dim=128,
            **factory_kwargs,
        )
        self.down_proj = Linear(
            in_features=1024,
            out_features=640,
            bias=False,
            **factory_kwargs,
        )
        self.mm1 = MMAttention(
            a_in_dim=2560,
            b_in_dim=2048,
            heads=20,
            head_dim=128,
            **factory_kwargs,
        )
        self.up_proj = Linear(
            in_features=2560,
            out_features=4096,
            bias=False,
            **factory_kwargs,
        )
        self.mm2 = MMAttention(
            a_in_dim=1024,
            b_in_dim=2048,
            heads=8,
            head_dim=128,
            **factory_kwargs,
        )

    def forward(
        self,
        a: FloatTensor,
        b: FloatTensor,
        b_seqused: IntTensor,
    ) -> FloatTensor:
        out: MMOut = self.mm0(
            a=a,
            b=b,
            b_seqused=b_seqused,
        )
        a, b = out
        a = self.down_proj(a)
        a = rearrange(a, "... (h nh) (w nw) e -> ... h w (nh nw e)", nh=2, nw=2)
        out = self.mm1(
            a=a,
            b=b,
            b_seqused=b_seqused,
        )
        a, b = out
        a = self.up_proj(a)
        a = rearrange(a, "... h w (nh nw e) -> ... (h nh) (w nw) e", nh=2, nw=2)
        out = self.mm2(
            a=a,
            b=b,
            b_seqused=b_seqused,
        )
        a, b = out
        return a


device = torch.device("cuda")
dtype = torch.float16

torch.manual_seed(42)

model = Model(
    device=device,
    dtype=dtype,
).eval()

bsz = 1
a_seqh = 92
a_seqw = 92
b_seq = 512
a_dim = model.mm0.a_qkv_proj.in_features  # 1024
b_dim = model.mm0.b_qkv_proj.in_features  # 2048
seq = a_seqh * a_seqw + b_seq  # 8976
b_seqused = 364
seqused = a_seqh * a_seqw + b_seqused  # 8828

generator = torch.Generator(device=device).manual_seed(42)
a = torch.randn(
    bsz, a_seqh, a_seqw, a_dim, device=device, dtype=dtype, generator=generator
)
b = torch.randn(bsz, b_seq, b_dim, device=device, dtype=dtype, generator=generator)

cu_seqlens = torch.tensor([0, seq], device=device, dtype=torch.int32)
b_seqused_t = torch.tensor([b_seqused], device=device, dtype=torch.int32)

model = torch.compile(model, dynamic=False)

with inference_mode():
    for _ in trange(1024 * 20):
        out: FloatTensor = model(
            a=a,
            b=b,
            b_seqused=b_seqused_t,
        )
        assert out.isfinite().all()

Birch-san, Mar 11 '25

> I might add an option to zero out the output before attn fwd computation or zero out the gradient before the attn bwd computation so that there's no uninitialized values (at the cost of small slowdown).

thanks for this, but I'm not sure how it would help. is the thesis here that despite my efforts I must still somehow be attending to pad tokens, or consuming attention outputs at pad query positions?

I'm not sure how the images I'm outputting could be "identical often but black otherwise" if they were consistently attending to or returning uninitialized padding. I'm likewise not sure how to explain why the forward pass only fails under torch compilation, and even then only sometimes.

we're also now experiencing intermittent black images with another model we deployed, except this time the problem persists even if I avoid using seqused. need to dig into it to figure out which operation is sensitive.

Birch-san, Mar 12 '25

My original hypothesis is that the memory locations in the output & gradients corresponding to unused tokens are uninitialized. You might get lucky and most of the time these locations happen to be zero, so you'd still get correct answers. To be safe one typically has to explicitly zero them out.

Other usual causes of nondeterministic incorrect answers are race conditions and illegal memory access. I think I've tested those pretty extensively but one can never rule out race conditions. These are hard to reproduce. I typically run the same kernel 1000 times to see if they all produce the same result. https://github.com/Dao-AILab/flash-attention/blob/cdaf2de6e95cb05400959b5ab984f66e4c7df317/hopper/test_flash_attn.py#L991
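The repeat-and-compare check described here can be written generically. This is a minimal sketch, assuming `fn` is a zero-argument closure that reruns one kernel call on fixed inputs (for example the attention call from a reproducer above); `check_repeatable` is a hypothetical helper, not the linked test.

```python
from typing import Callable, Tuple

import torch


def check_repeatable(
    fn: Callable[[], torch.Tensor], n_iters: int = 1000
) -> Tuple[int, int]:
    """Call fn n_iters times on the same inputs.

    Returns (mismatches, nonfinite): how many later results differ bitwise from
    the first, and how many results (including the first) contain NaN/Inf.
    """
    ref = fn().clone()
    mismatches = 0
    nonfinite = 0 if torch.isfinite(ref).all() else 1
    for _ in range(n_iters - 1):
        out = fn()
        if not torch.isfinite(out).all():
            nonfinite += 1
        if not torch.equal(out, ref):  # bitwise comparison, no tolerance
            mismatches += 1
    return mismatches, nonfinite
```

Comparing with `torch.equal` rather than `torch.allclose` is deliberate: a race condition can show up as a tiny difference, and NaN never compares equal, so a NaN result is counted both as non-finite and as a mismatch.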

tridao, Mar 12 '25