FA3 support for GB200
Hi team,
I'm trying to use FA3 for inference with WAN2.2 on GB200 chips, which are sm_100. However, I'm seeing errors such as torch.AcceleratorError: CUDA error: unspecified launch failure, among others. I wanted to confirm whether it's possible to use FA3 on Blackwell (i.e., GB200) for inference, or whether it is specific to Hopper.
FA2 works perfectly fine on the same setup; I'm only facing issues with FA3.
If it is possible, is there any documentation you can share for the integration? That would be really helpful. @tridao, @johnnynunez
Did you try compiling the CuTe version of flash-attention? @tridao is moving the whole repository to the CuTe DSL.
cd flash-attention/flash_attn/cute
uv build --wheel . -v --no-build-isolation --out-dir flash-attention/wheels
uv pip install flash-attention/wheels/flash_attn_cute*.whl --prerelease=allow
To use the above version, will I need to update the references in the WAN code that currently use FA2 so that they point to this library?
Is my understanding correct? I tried the above build steps while keeping the existing FA2 calls, but it didn't work, so I wanted to confirm.
For FA2, you need to compile the library. But FA2 is not well suited to Blackwell; it is better to use FA4.
Yeah, the CuTe version you shared above is FA4, right? I tried the commands you shared, but when I ran the WAN model, it was not able to find and use the flash-attn CuTe files in site-packages. Will I need to update the WAN model code to reference the FA CuTe interfaces to make it work with FA4?
Or are there any steps I might be missing? This is the error I'm getting:
Traceback (most recent call last):
File "/workspace/Wan2.2/generate.py", line 20, in <module>
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
File "/workspace/Wan2.2/wan/utils/prompt_extend.py", line 18, in <module>
from flash_attn import flash_attn_varlen_func
ImportError: cannot import name 'flash_attn_varlen_func' from 'flash_attn' (unknown location)
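A quick way to see which flash_attn pieces are actually importable in a given environment is a small probe like the one below. probe_flash_attn is a hypothetical diagnostic helper (it is not part of flash-attention) and uses only the standard library. The "(unknown location)" part of the error is what Python prints when the module has no __file__, which usually means flash_attn was found only as a namespace package (a directory without __init__.py); that this comes from installing only the flash_attn/cute wheel is an assumption about this setup, not something confirmed in the thread.

```python
import importlib


def probe_flash_attn(module_name: str, attr: str) -> dict:
    """Report whether module_name imports and exposes attr (hypothetical helper)."""
    result = {"module": module_name, "importable": False,
              "has_attr": False, "file": None}
    try:
        mod = importlib.import_module(module_name)
    except ImportError:  # includes ModuleNotFoundError
        return result
    result["importable"] = True
    # Namespace packages (no __init__.py) have no usable __file__; this is
    # what produces "(unknown location)" in "cannot import name" errors.
    result["file"] = getattr(mod, "__file__", None)
    result["has_attr"] = hasattr(mod, attr)
    return result


if __name__ == "__main__":
    # FA2-style entry point vs. the FA4 (CuTe) interface module
    for name in ("flash_attn", "flash_attn.cute.interface"):
        print(probe_flash_attn(name, "flash_attn_varlen_func"))
```

If flash_attn reports importable=True with file=None and has_attr=False, the FA2 top-level API is not installed, and code importing from flash_attn directly will fail as in the traceback above.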
Yes, you will probably have to adapt your code to the new call API.
FA3 does not work on Blackwell, for instance because it uses wgmma instructions, which are deprecated. To use FA4, the API is:
from flash_attn.cute.interface import (
    flash_attn_func,
    flash_attn_varlen_func,
    flash_attn_combine,
)
as seen for example in https://github.com/Dao-AILab/flash-attention/blob/main/tests/cute/test_flash_attn.py. This dispatches to whichever architecture you have set.
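Since FA3 targets Hopper and does not work on Blackwell per the reply above, code that has to run on both Hopper and GB200 could choose the variant from the device's compute capability. This is a minimal sketch: pick_flash_attn_variant is a hypothetical helper, and its mapping encodes only what this thread states (FA3 for sm_90, FA4 for sm_100, FA2 otherwise).

```python
from typing import Tuple


def pick_flash_attn_variant(capability: Tuple[int, int]) -> str:
    """Map a CUDA compute capability (major, minor) to a flash-attention variant.

    Hypothetical helper based only on this thread:
    - sm_90 (Hopper): FA3
    - sm_100 (Blackwell, e.g. GB200): FA4, i.e. flash_attn.cute
    - anything else: FA2 as the fallback
    """
    major, _minor = capability
    if major == 9:
        return "fa3"
    if major == 10:
        return "fa4"
    return "fa2"
```

In practice the argument would come from torch.cuda.get_device_capability(), which returns (10, 0) on an sm_100 device such as GB200, so this sketch would select "fa4" there.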
Yes, thank you for the thoughts and instructions. I updated the model to reference the CuTe implementation (FA4), and since I was using flash_attn_varlen_func, I moved that call from FA2 to FA4. However, I'm not seeing consistent results between the two.
CuTe variant:
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    seqused_q: Optional[torch.Tensor] = None,
    seqused_k: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    window_size: Tuple[Optional[int], Optional[int]] = (None, None),
    learnable_sink: Optional[torch.Tensor] = None,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: Optional[bool] = None,
):
FA2 variant:
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    block_table=None,
):
Since my workloads work fine with FA2, I tried to integrate the FA4 (CuTe) version. As seen above, FA4 does not take max_seqlen_q and max_seqlen_k, so I didn't pass them. However, the outputs of the two differ in this scenario, and the FA4 integration isn't working properly. Is there a workaround, or am I missing something needed to make the integration work?
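One way to keep existing FA2-style call sites unchanged is a thin shim that accepts the FA2 argument list and forwards only what the FA4 signature above accepts. This is a hypothetical sketch (make_fa2_style_varlen is not part of flash-attention). Mapping FA2's window_size=-1 to FA4's None, and taking element 0 of FA4's return value as the attention output, are assumptions. The shim only translates arguments; it does not by itself address the differing outputs.

```python
from typing import Any, Callable, Optional, Tuple


def make_fa2_style_varlen(fa4_varlen: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap an FA4-style varlen function so FA2-style calls keep working (hypothetical)."""

    def fa2_style_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k,
                         max_seqlen_q=None, max_seqlen_k=None,
                         dropout_p=0.0, softmax_scale=None, causal=False,
                         window_size=(-1, -1), softcap=0.0,
                         alibi_slopes=None, deterministic=False):
        # Accepted for FA2 compatibility but not forwarded to FA4.
        del max_seqlen_q, max_seqlen_k, deterministic
        if dropout_p != 0.0:
            raise ValueError("FA4 varlen has no dropout_p; use 0.0 for inference")
        if alibi_slopes is not None:
            raise ValueError("the FA4 varlen signature has no alibi_slopes")
        # Assumption: FA2's -1 ("infinite window") corresponds to FA4's None.
        left, right = window_size
        fa4_window: Tuple[Optional[int], Optional[int]] = (
            None if left < 0 else left,
            None if right < 0 else right,
        )
        out = fa4_varlen(q, k, v,
                         cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
                         softmax_scale=softmax_scale, causal=causal,
                         window_size=fa4_window, softcap=softcap)
        # Assumption: FA4 returns a tuple whose first element is the output;
        # FA2 (with return_attn_probs=False) returns the output tensor only.
        return out[0] if isinstance(out, tuple) else out

    return fa2_style_varlen
```

Usage would look like wrapping the FA4 function once, e.g. flash_attn_varlen_func = make_fa2_style_varlen(fa4_varlen_func) after importing flash_attn_varlen_func from flash_attn.cute.interface as fa4_varlen_func, so existing calls that still pass max_seqlen_* do not need to be edited.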
Can you clarify what isn't working properly? Also FA4 varlen does not yet support max_seqlen_*.
Re "Can you clarify what isn't working properly?": the outputs of flash_attn_varlen_func differ between FA2 and the CuTe version.
When I replaced the FA2 call in my model with the CuTe interface call, I had to omit the max_seqlen_* parameters because FA4 doesn't support them (which I suspect might be the source of the error). When I compared the outputs of the FA2 and FA4 flash_attn_varlen_func, they were not the same: my model (WAN2.2) generated correct videos with FA2, but once I integrated FA4, the output videos were just random pixels.
So, is it safe to assume that, since my FA2 call uses max_seqlen_* and FA4 currently doesn't support it, the results are expected to differ?
Any thoughts on the above @tridao ?
max_seqlen_* is no longer necessary in FA4. Are you running the fwd or fwd & bwd?
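For reference, the max_seqlen_* values that FA2 asks for can be computed from cu_seqlens alone: cu_seqlens holds the prefix-sum offsets of the packed sequences, and max_seqlen is the largest gap between consecutive offsets. A minimal pure-Python sketch; varlen_metadata is a hypothetical helper, and in real code these would be int32 CUDA tensors, as in the comparison script below.

```python
from itertools import accumulate
from typing import List, Tuple


def varlen_metadata(seqlens: List[int]) -> Tuple[List[int], int]:
    """Build cu_seqlens and max_seqlen for a packed batch (hypothetical helper).

    cu_seqlens starts at 0 and holds running totals of the lengths;
    max_seqlen is the longest single sequence, i.e. the largest
    difference between consecutive cu_seqlens entries.
    """
    cu_seqlens = [0, *accumulate(seqlens)]
    max_seqlen = max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))
    return cu_seqlens, max_seqlen
```

For example, varlen_metadata([3, 5, 2]) returns ([0, 3, 8, 10], 5), and the single-sequence case used below, varlen_metadata([35073]), returns ([0, 35073], 35073).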
I'm only running the forward pass. I tried to integrate FA4 in place of the FA2 call, which gave me different results. The only difference I can see between the FA4 and FA2 calls is the set of arguments: I passed the arguments from the FA2 call that FA4 also accepts, and omitted the others (max_seqlen_q, dropout_p, deterministic), since I'm running inference only. But the outputs of the FA2 and FA4 calls differ, so I wanted some help: am I missing an argument that might cause this difference, or is the difference expected?
For more context, I verified this with a script that calls both the FA2 and FA4 variants and compares their outputs. They don't match within a 1e-4 tolerance. @tridao
FA2 & FA4 comparison script
import torch

# FA2 ("vanilla") varlen kernel
try:
    from flash_attn import flash_attn_varlen_func as fa_vanilla
except ImportError:
    fa_vanilla = None
    print("WARNING: Could not import fa_vanilla (FA2).")

# FA4 (CuTe DSL) varlen kernel
try:
    from flash_attn.cute.interface import flash_attn_varlen_func as fa_cute
except ImportError:
    fa_cute = None
    print("WARNING: Could not import fa_cute (FA4).")


def compare_attention_outputs():
    # Both kernels are required for the comparison
    if fa_vanilla is None or fa_cute is None:
        print("ERROR: Both fa_vanilla and fa_cute must be importable.")
        return

    print("--- 1. Setting up Test Parameters ---")
    s = 35073  # Sequence length
    h = 32     # Number of heads
    d = 128    # Head dimension
    device = "cuda"
    dtype = torch.bfloat16

    # Packed varlen layout: (total_tokens, num_heads, head_dim)
    q = torch.randn(s, h, d, device=device, dtype=dtype)
    k = torch.randn(s, h, d, device=device, dtype=dtype)
    v = torch.randn(s, h, d, device=device, dtype=dtype)
    # A single sequence of length s, so the offsets are [0, s]
    cu_seqlens_q = torch.tensor([0, s], device=device, dtype=torch.int32)
    cu_seqlens_k = torch.tensor([0, s], device=device, dtype=torch.int32)
    max_seqlen = s  # Needed by FA2 only; FA4 varlen does not take it

    print("\n--- 2. Running Kernels ---")
    print("-> Running FA Vanilla kernel...")
    out_fa_vanilla = fa_vanilla(
        q=q, k=k, v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
    )
    print("-> Running FA CuTe kernel...")
    out_fa_cute = fa_cute(
        q=q, k=k, v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
    )

    print("\n--- 3. Numerical Validation ---")
    TOLERANCE = 1e-4
    # FA4 returns a tuple; the first element is the attention output
    out_fa_cute = out_fa_cute[0]
    if out_fa_vanilla.shape != out_fa_cute.shape:
        print(f"SHAPE MISMATCH: Vanilla {out_fa_vanilla.shape} vs CuTe {out_fa_cute.shape}")
        return

    diff = torch.abs(out_fa_vanilla - out_fa_cute)
    max_abs_diff = diff.max().item()
    mean_abs_diff = diff.mean().item()
    print(f" Max Absolute Difference (MAE): {max_abs_diff:.5e}")
    print(f" Mean Absolute Difference (MeanAE): {mean_abs_diff:.5e}")
    print(f" Comparison Tolerance: {TOLERANCE:.5e}")

    if max_abs_diff < TOLERANCE:
        print("\nVALIDATION SUCCESSFUL: Outputs are numerically equivalent.")
        print(" The optimized kernel is safe to use for deployment.")
    else:
        print("\n❌ VALIDATION FAILED: Outputs diverge beyond acceptable tolerance.")
        print(" The optimized kernel may have a numerical or stability issue.")


if __name__ == "__main__":
    compare_attention_outputs()
Output
--- 1. Setting up Test Parameters ---
--- 2. Running Kernels ---
-> Running FA Vanilla kernel...
-> Running FA CuTe kernel...
seq lengths 35073 35073
gmem_tiled_copy_O: Tiled Copy
Tiler MN: (8:1,64:1)
TV Layout tiled: ((8,8),8):((64,1),8)
Copy Atom
ThrID: 1:0
TV Layout Src: (1,8):(0,1)
TV Layout Dst: (1,8):(0,1)
Value type: bf16
--- 3. Numerical Validation ---
Max Absolute Difference (MAE): 2.44141e-04
Mean Absolute Difference (MeanAE): 1.58548e-05
Comparison Tolerance: 1.00000e-04
❌ VALIDATION FAILED: Outputs diverge beyond acceptable tolerance.
The optimized kernel may have a numerical or stability issue.
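The script applies a purely absolute 1e-4 tolerance to bfloat16 outputs. As a rough yardstick (not a verdict on either kernel), the spacing between adjacent bfloat16 values can be computed directly: bfloat16 keeps 8 significant bits, so for outputs with magnitude in [2^-5, 2^-4) one bf16 step is 2^-12, about 2.44e-4, which equals the reported max difference. The sketch assumes normal (non-subnormal) values; the 0.04 magnitude in the usage line is illustrative, not a value measured in the run. For comparison, torch.testing.assert_close defaults to rtol=1.6e-2 and atol=1e-5 for bfloat16. This only speaks to the size of the differences in this random-input test.

```python
import math


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values near |x| (normal range only).

    bfloat16 has 8 significant bits (7 stored + 1 implicit), so for
    2**e <= |x| < 2**(e + 1) the spacing is 2**(e - 7).
    """
    if x == 0:
        raise ValueError("spacing at 0 is not covered by this sketch")
    e = math.floor(math.log2(abs(x)))
    return 2.0 ** (e - 7)


def within_tolerance(diff: float, ref: float, atol: float, rtol: float) -> bool:
    """Same elementwise rule as torch.allclose: |a - b| <= atol + rtol * |b|."""
    return diff <= atol + rtol * abs(ref)


if __name__ == "__main__":
    print(bf16_spacing(0.04))  # one bf16 step for values in [2**-5, 2**-4)
    print(within_tolerance(2.44141e-04, 0.04, atol=1e-5, rtol=1.6e-2))
```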