
Fix +/-inf in LSE returned by forward

Open: sgrigory opened this issue 1 year ago • 1 comment

Forward op was returning +inf in LSE for queries that have no keys to attend to, e.g. when the K/V length happens to be 0. This diverges from the definition LSE = log(exp(s_1) + ... + exp(s_n)), which for an empty set of keys gives log(0) = -inf. This PR fixes that, which allows feeding the output LSE directly into ops like merge_attentions without postprocessing.
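Why -inf is the convenient value can be sketched in plain Python. This is not the flash-attention kernel or the merge_attentions implementation; the helpers `attend`, `lse_add` and `merge_partial` are hypothetical, written only to illustrate the math under the assumption that partial results over disjoint key sets are combined by weighting each with exp(lse_i - lse_total). With LSE = -inf, a chunk with no keys gets weight exp(-inf) = 0 and drops out cleanly.

```python
import math

def attend(scores, values, dim):
    """Reference softmax attention for one query (hypothetical helper).

    Returns (output, lse). With no keys, the output is a zero vector of
    length `dim` and the LSE is log(0) = -inf, matching the definition.
    """
    if not scores:
        return [0.0] * dim, -math.inf
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total
           for d in range(dim)]
    return out, m + math.log(total)

def lse_add(a, b):
    """log(exp(a) + exp(b)); an LSE of -inf acts as an empty contribution."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf  # both sides had no keys
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def merge_partial(out_a, lse_a, out_b, lse_b):
    """Combine attention computed over two disjoint key sets via their LSEs."""
    lse = lse_add(lse_a, lse_b)
    if lse == -math.inf:
        return [0.0] * len(out_a), lse
    w_a = math.exp(lse_a - lse)  # exp(-inf) == 0.0 for an empty chunk
    w_b = math.exp(lse_b - lse)
    return [w_a * x + w_b * y for x, y in zip(out_a, out_b)], lse

scores = [0.5, 1.0, 2.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out_all, lse_all = attend(scores, values, dim=2)
out_empty, lse_empty = attend([], [], dim=2)  # no keys: LSE is -inf
merged, merged_lse = merge_partial(out_all, lse_all, out_empty, lse_empty)
```

Merging with the empty chunk leaves the other result unchanged. If the empty chunk instead reported +inf, the subtraction inf - inf would produce NaN and the merged output would be NaN, which is the postprocessing the PR description wants to avoid.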

pytest tests/test_flash_attn.py
...
======================================================================================== 268004 passed, 152064 skipped in 4404.00s (1:13:23) =========================================================================================

sgrigory, Jun 03 '24 14:06

One issue I can see is that in the backward pass, if lse = +inf then exp(qk - lse) returns 0, which is what we want. If lse = -inf, then exp(qk - lse) would blow up.

tridao, Jun 27 '24 09:06

QQ: do we plan to merge this PR? It has been pending for months.

GD06, Jan 03 '25 03:01


Sorry, I didn't follow up on @tridao's comment above. Basically, I think there should be no NaNs after this change, because the code already checks for -inf before computing exp(score - lse) in the backward pass: https://github.com/Dao-AILab/flash-attention/blob/40fa35acd8269ebc4777e682f8bfb690f1f12bb5/csrc/flash_attn/src/softmax.h#L75
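The concern and the guard can be illustrated in plain Python. This is a simplified model of recomputing probabilities P_j = exp(s_j - lse) from the saved LSE, not the CUDA code in softmax.h; the function name `recompute_probs` and its `guard` flag are hypothetical. It assumes masked or missing keys carry a score of -inf, and that the guard substitutes 0 for an LSE of -inf, as described for the linked check.

```python
import math

def recompute_probs(scores, lse, guard=True):
    """Recompute softmax probabilities for one query row from a saved LSE.

    Hypothetical sketch: masked/missing keys have score -inf. With
    guard=True, an LSE of -inf is replaced by 0.0 before the subtraction,
    so every probability in an empty row becomes exp(-inf) = 0.0.
    """
    if guard and lse == -math.inf:
        lse = 0.0
    return [math.exp(s - lse) for s in scores]

inf = math.inf
empty_row = [-inf, -inf]  # a query with no valid keys

guarded = recompute_probs(empty_row, -inf)                # [0.0, 0.0]
unguarded = recompute_probs(empty_row, -inf, guard=False) # -inf - (-inf) is NaN
blown_up = recompute_probs([1.0], -inf, guard=False)      # exp(+inf) is inf
```

Without the guard, a finite score gives exp(+inf) = inf (the blow-up @tridao describes) and a masked score gives exp(NaN) = NaN; with it, an empty row contributes exact zeros to the gradient.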

Also, in the Hopper kernel we write -inf for out-of-bounds positions:

https://github.com/Dao-AILab/flash-attention/blob/40fa35acd8269ebc4777e682f8bfb690f1f12bb5/hopper/epilogue_fwd.hpp#L379-L385

https://github.com/Dao-AILab/flash-attention/blob/40fa35acd8269ebc4777e682f8bfb690f1f12bb5/hopper/flash_fwd_kernel_sm90.h#L406

If that makes sense and the FA2 code is still relevant, I can add a test that covers the backward behaviour in this situation to make the PR mergeable.

sgrigory, Jan 10 '25 13:01