FA3: no perf gain noticed.
I'm using FlashAttention in my code. I installed FA3, but I noticed no perf gain when I ran the model again. My CUDA version is 12.4 and I'm on an H100.
Do we need to modify something, or should we see the perf gain automatically?
Same here — using a head dim of 256 in fp16 fwd/bwd
FA3 is a beta release, you'd need to call a different interface to use it (for now). https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_attn_interface.py
As we roll it out (e.g. after the v3.0.0 tag), the old interface will eventually call the faster code automatically.
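For anyone who wants to try it right away, here is a minimal sketch of calling the beta interface directly. It assumes the hopper build is installed (e.g. `cd hopper && python setup.py install`) and exposes a top-level `flash_attn_interface` module, as in the file linked above; names, arguments, and return values may still change before the v3.0.0 tag.

```python
import torch

from flash_attn_interface import flash_attn_func   # FA3 (hopper) beta interface
# from flash_attn import flash_attn_func           # FA2 interface, for comparison

# (batch, seqlen, nheads, headdim) in fp16/bf16 on an H100
q = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Note: unlike FA2, the beta interface returned the softmax LSE alongside the
# output at the commits discussed here; drop the second value if that changes.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
```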
Thanks for the answer. So, to make it clear: in this class, for example:
https://github.com/Dao-AILab/flash-attention/blob/5018ac6ac531aabdb05c8af1ba3d98a2235bdbde/flash_attn/modules/mha.py#L53
do we replace flash_attn_varlen_qkvpacked_func and flash_attn_qkvpacked_func
with flash_attn_func and flash_attn_qkvpacked_func from the hopper folder?
Thanks
ya
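In case it helps to see the swap spelled out: the hopper interface takes separate q, k, v, so a packed call has to be unbound first. A rough sketch (the wrapper name here is made up for illustration; the .contiguous() calls relate to the early-commit requirement discussed below):

```python
import torch

from flash_attn_interface import flash_attn_func as flash_attn_func_v3  # FA3 beta

def qkvpacked_attn_v3(qkv, softmax_scale=None, causal=False):
    """Hypothetical stand-in for flash_attn_qkvpacked_func on top of the FA3 beta API.

    qkv: (batch, seqlen, 3, nheads, headdim)
    """
    q, k, v = qkv.unbind(dim=2)
    # Early beta commits required contiguous inputs; newer commits dropped that
    # check (see the discussion below), so these copies may be unnecessary.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    out, *_ = flash_attn_func_v3(q, k, v, softmax_scale=softmax_scale, causal=causal)
    return out
```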
I haven't confirmed this, but the lack of speedup may be partly due to the contiguous requirement (#1087): when benchmarking, I'm comparing against FA2 QKVPacked without forcing the tensors to be contiguous.
FA3 now doesn't require tensors to be contiguous.
Hm, I just compiled the most recent commit (3f1b4d38), and despite the maybe_contiguous, I still got an "Input tensors must be contiguous" error (like this comment mentions).
I found I need to do something like:
```python
q, k, v = qkv.unbind(dim=2)
flash_attn_func_v3(q.contiguous(), k.contiguous(), v.contiguous())
```
Please make sure you use the latest commit (e.g. uninstall previous versions). The newest code doesn't do any contiguous check anymore.
Sorry, I misspoke; let me update to remove those checks.
done, thanks for checking
Thanks! Just pulled and was able to run it without the forced .contiguous(). I'll note [mainly for other users, if they run into it] that I found flash_attn_qkvpacked_func [FA2] to still be quite a bit faster (~10-15%, end-to-end training in my case) than the regular flash_attn_func [FA2 or FA3], which was unexpected. Looking at the code, I wouldn't expect such a large difference, since qkvpacked appears to be a thin wrapper, but these are the results I'm getting. It might be an oddity of my specific configuration [BS=12, seqlen=1024, headdim=128, 16 heads on a single H100], but I thought I should note it.
qkvpacked avoids the concat in the backward pass (backward of unbind is concat).
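A small plain-PyTorch illustration of that point (not FA code, just showing where the extra work comes from):

```python
import torch

# After unbind, autograd has to combine dq, dk, dv back into one packed gradient
# (the backward of unbind is a stack/cat), an extra copy that the qkvpacked path
# avoids by writing the three gradients directly into a single dqkv buffer.
qkv = torch.randn(4, 128, 3, 8, 64, requires_grad=True)
q, k, v = qkv.unbind(dim=2)                     # three views into qkv
loss = (1.0 * q + 2.0 * k + 3.0 * v).sum()      # stand-in for attention
loss.backward()
print(qkv.grad.shape)                           # torch.Size([4, 128, 3, 8, 64])
```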
Also just pulled and tried removing the manual .contiguous() calls. I'm hitting
File "/workspace/flash-attention/hopper/flash_attn_interface.py", line 46, in _flash_attn_backward
dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
RuntimeError: dq must have contiguous last dimension
Does backward still require contiguous grads?
As the error says: contiguous last dimension, i.e. the last stride should be 1.
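In case it helps others hitting this, a quick way to check (plain PyTorch, nothing FA-specific): the requirement is only on the innermost (head_dim) stride, not full contiguity.

```python
import torch

x = torch.randn(2, 1024, 16, 128, device="cuda", dtype=torch.float16)

# A transpose that keeps head_dim innermost still has stride(-1) == 1, so it
# satisfies the "contiguous last dimension" check even though the tensor as a
# whole is not contiguous.
x_t = x.transpose(1, 2)                       # (batch, nheads, seqlen, headdim)
print(x_t.is_contiguous(), x_t.stride(-1))    # False 1

# Anything that moves head_dim out of the innermost position does not.
x_bad = x.transpose(2, 3)
print(x_bad.stride(-1))                       # 128 -> triggers the error above
x_ok = x_bad.contiguous()                     # one copy restores stride(-1) == 1
```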