flash-attention
Fix +/-inf in LSE returned by forward
Forward op was returning +inf in LSE for queries which have no keys to attend to, e.g. when the K/V length happens to be 0. This diverges from the definition LSE = log(exp(s_1) + ... + exp(s_k)), which for an empty set of scores gives log(0) = -inf.
This PR fixes it, which allows feeding the output LSE directly into ops like merge_attentions without postprocessing.
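A minimal NumPy sketch of the point above: for a query with zero keys the log-sum-exp is log of an empty sum, i.e. -inf, and that value is exactly what makes LSE-based merging work without postprocessing. The `merge_weights` helper below is hypothetical, just mirroring the kind of combination `merge_attentions` performs; it is not the actual API.

```python
import numpy as np

def logsumexp(scores):
    # LSE = log(exp(s_1) + ... + exp(s_k)); for k = 0 the sum is 0, so LSE = -inf.
    if len(scores) == 0:
        return -np.inf
    m = np.max(scores)  # subtract the max for numerical stability
    return m + np.log(np.sum(np.exp(scores - m)))

# A query with no keys to attend to:
print(logsumexp(np.array([])))  # -> -inf

# Hypothetical merge of two partial attention chunks by their LSEs:
# a chunk with lse = -inf gets weight exp(-inf - merged) = 0, so an
# empty chunk contributes nothing to the merged output.
def merge_weights(lse_a, lse_b):
    merged = np.logaddexp(lse_a, lse_b)  # log(exp(a) + exp(b)), -inf-safe
    return np.exp(lse_a - merged), np.exp(lse_b - merged)

print(merge_weights(1.5, -np.inf))  # -> (1.0, 0.0)
```

With the old +inf convention, the empty chunk's weight would instead be exp(+inf - merged), which is why callers previously needed a postprocessing step.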
pytest tests/test_flash_attn.py
...
======================================================================================== 268004 passed, 152064 skipped in 4404.00s (1:13:23) =========================================================================================
One issue I can see is that in the backward pass, if lse = +inf then exp(qk - lse) returns 0, which is what we want. If lse = -inf, then exp(qk - lse) would blow up to +inf.
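The concern above can be sketched numerically; this is a hedged Python illustration of the recomputed probability in the backward pass, with a guard for lse = -inf of the same kind as the check later referenced in softmax.h (the function name is made up for illustration):

```python
import math

def p_backward(qk, lse):
    # Probability recomputed in the backward pass as exp(qk - lse).
    # Guard against lse == -inf (a query with no keys): without it,
    # exp(qk - (-inf)) = exp(+inf) would overflow to inf.
    if math.isinf(lse) and lse < 0:
        return 0.0
    return math.exp(qk - lse)

print(p_backward(0.3, math.inf))   # exp(-inf) = 0.0, the desired result
print(p_backward(0.3, -math.inf))  # guarded: 0.0 instead of inf
```

So with the +inf convention the backward pass silently does the right thing, while the -inf convention is only safe if the kernel explicitly checks for it.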
QQ: do we plan to merge this PR, as it has been pending for months?
Sorry, I didn't follow up on @tridao's comment above. Basically, I think there should be no NaNs after this change, because the code actually checks for -inf before computing exp(score - lse) in the backward pass:
https://github.com/Dao-AILab/flash-attention/blob/40fa35acd8269ebc4777e682f8bfb690f1f12bb5/csrc/flash_attn/src/softmax.h#L75
Also, in the Hopper kernel we write -inf for out-of-bounds positions:
https://github.com/Dao-AILab/flash-attention/blob/40fa35acd8269ebc4777e682f8bfb690f1f12bb5/hopper/epilogue_fwd.hpp#L379-L385
https://github.com/Dao-AILab/flash-attention/blob/40fa35acd8269ebc4777e682f8bfb690f1f12bb5/hopper/flash_fwd_kernel_sm90.h#L406
If that makes sense and the FA2 code is still relevant, I can add a test which covers the backward behaviour in this situation to make the PR mergeable.