
FA3: No perf gain noticed.

AdamLouly opened this issue 1 year ago • 14 comments

I'm using flash attention in my code. I installed FA3, but I noticed no perf gain when I ran the model again. My CUDA is 12.4 and I'm using an H100.

Do we need to modify something, or should we see the perf gain automatically?

AdamLouly avatar Jul 30 '24 15:07 AdamLouly

Same here — using a head dim of 256 in fp16 fwd/bwd

alexanderswerdlow avatar Jul 30 '24 20:07 alexanderswerdlow

FA3 is a beta release; you'd need to call a different interface to use it (for now): https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_attn_interface.py

As we roll it out (e.g. after the v3.0.0 tag), the old interface will eventually call the faster code automatically.

tridao avatar Jul 30 '24 20:07 tridao
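A minimal sketch of what calling the beta interface directly can look like, assuming the hopper extension has been built so that flash_attn_interface (i.e. hopper/flash_attn_interface.py) is importable; exact signatures and return values may differ between beta commits:

import torch
from flash_attn_interface import flash_attn_func as flash_attn_func_v3

# (batch, seqlen, nheads, headdim) fp16 tensors on an H100
q = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 16, 128, dtype=torch.float16, device="cuda")

result = flash_attn_func_v3(q, k, v)
# depending on the commit this returns either out or (out, softmax_lse)
out = result[0] if isinstance(result, tuple) else result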

> FA3 is a beta release; you'd need to call a different interface to use it (for now): https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_attn_interface.py
>
> As we roll it out (e.g. after the v3.0.0 tag), the old interface will eventually call the faster code automatically.

Thanks for the answer. So to make it clear, for example in this class:

https://github.com/Dao-AILab/flash-attention/blob/5018ac6ac531aabdb05c8af1ba3d98a2235bdbde/flash_attn/modules/mha.py#L53

do we replace these: flash_attn_varlen_qkvpacked_func and flash_attn_qkvpacked_func

with these from the hopper folder:

flash_attn_func and flash_attn_qkvpacked_func

Thanks

AdamLouly avatar Jul 31 '24 09:07 AdamLouly

ya

tridao avatar Aug 01 '24 19:08 tridao
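For concreteness, a rough sketch of the substitution being discussed; the FA2 import below is the existing one from flash_attn, while the hopper side assumes the beta build is importable as flash_attn_interface and that the packed variants may not be available there yet:

import torch
from flash_attn import flash_attn_qkvpacked_func                        # FA2 packed interface used by mha.py
from flash_attn_interface import flash_attn_func as flash_attn_func_v3  # FA3 beta (hopper folder)

qkv = torch.randn(2, 1024, 3, 16, 128, dtype=torch.float16, device="cuda")

out_fa2 = flash_attn_qkvpacked_func(qkv)   # FA2 path: packed qkv

q, k, v = qkv.unbind(dim=2)                # FA3 beta path: unpack and call the plain function
result = flash_attn_func_v3(q, k, v)       # may return (out, softmax_lse) on some commits
out_fa3 = result[0] if isinstance(result, tuple) else result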

I haven't confirmed this, but the lack of speedup may be partly due to the contiguous requirement (#1087): when benchmarking, I'm comparing against FA2 QKVPacked without forcing the tensors to be contiguous.

alexanderswerdlow avatar Aug 05 '24 21:08 alexanderswerdlow

FA3 now doesn't require tensors to be contiguous.

tridao avatar Aug 05 '24 21:08 tridao

Hm, I was just compiling with the most recent commit (3f1b4d38), and despite the maybe_contiguous, I still got an "Input tensors must be contiguous" error (like this comment mentions).

I found I need to do something like:

q, k, v = qkv.unbind(dim=2)  # split packed (batch, seqlen, 3, nheads, headdim) into q, k, v views
flash_attn_func_v3(q.contiguous(), k.contiguous(), v.contiguous())  # force dense copies before calling the hopper kernel

alexanderswerdlow avatar Aug 05 '24 21:08 alexanderswerdlow

Please make sure you use the latest commit (e.g. uninstall previous versions). The newest code doesn't do any contiguous check anymore.

tridao avatar Aug 05 '24 21:08 tridao

Sorry, I misspoke; let me update to remove those checks.

tridao avatar Aug 05 '24 21:08 tridao

Done, thanks for checking.

tridao avatar Aug 05 '24 21:08 tridao

Thanks! Just pulled and was able to run it without the forced .contiguous(). I'll note [mainly for other users, if they run into it] that I found flash_attn_qkvpacked_func [FA2] to still be quite a bit faster (~10-15% end-to-end training in my case) than the regular flash_attn_func [FA2 or FA3], which was unexpected. Looking at the code, I wouldn't expect such a large difference, since qkvpacked appears to be a thin wrapper, but these are the results I'm getting. It might be an oddity with my specific configuration [BS=12, Seqlen=1024, Headdim=128, Heads=16 on a single H100], but I thought I should note it.

alexanderswerdlow avatar Aug 05 '24 22:08 alexanderswerdlow

qkvpacked avoids the concat in the backward pass (backward of unbind is concat).

tridao avatar Aug 05 '24 22:08 tridao
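A tiny plain-PyTorch illustration of that point (nothing flash-attn specific; shapes are arbitrary):

import torch

qkv = torch.randn(2, 128, 3, 4, 64, requires_grad=True)  # (batch, seqlen, 3, nheads, headdim)
q, k, v = qkv.unbind(dim=2)       # three views into the packed tensor
(q * k * v).sum().backward()      # stand-in for an attention op on q, k, v
# autograd has to stitch dq, dk, dv back into one gradient for qkv,
# i.e. the backward of unbind is effectively a stack/concat:
print(qkv.grad.shape)             # torch.Size([2, 128, 3, 4, 64])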

Also just pulled and tried removing the manual .contiguous() calls. I'm hitting:

  File "/workspace/flash-attention/hopper/flash_attn_interface.py", line 46, in _flash_attn_backward
    dq, dk, dv, softmax_d, *rest = flashattn_hopper_cuda.bwd(
RuntimeError: dq must have contiguous last dimension

Does backward still require contiguous grads?

Fuzzkatt avatar Aug 05 '24 22:08 Fuzzkatt

As it says, contiguous last dimension, i.e. the last stride should be 1.

tridao avatar Aug 05 '24 22:08 tridao
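A small plain-PyTorch illustration of what "last stride should be 1" means, and one way to satisfy it if the extra copy is acceptable:

import torch

x = torch.randn(2, 4, 64, 128).transpose(-1, -2)  # shape (2, 4, 128, 64); last dim is now strided
print(x.stride(-1))     # 128, not 1 -> would trip the "contiguous last dimension" check
if x.stride(-1) != 1:
    x = x.contiguous()  # or keep the layout so the last dim stays dense upstream
print(x.stride(-1))     # 1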