long-context-attention
Incorrect version check for flash_attn leads to API incompatibility in v2.6.3
Description
The following error occurs when flash_attn == 2.6.3:
[rank0]: File "/home/xxx/miniconda3/envs/xxx/lib/python3.10/site-packages/yunchang/kernels/attention.py", line 132, in flash_attn_forward
[rank0]: block_out, block_lse, _, _ = _flash_attn_forward(
[rank0]: TypeError: _flash_attn_forward() got an unexpected keyword argument 'window_size_left'
Referring to the source code, this error originates from an inaccurate branch condition:
if flash_attn.__version__ < '2.6.3':  # <-- WRONG!
    block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
        # ...
        window_size=window_size,
        # ...
    )
else:
    block_out, block_lse, _, _ = _flash_attn_forward(
        # ...
        window_size_left=window_size[0],
        window_size_right=window_size[1],
        # ...
    )
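For illustration, this standalone snippet (hypothetical, not taken from the repo) shows why version 2.6.3 falls through to the new-API branch, and why raw string comparison is fragile in general:

installed = '2.6.3'
print(installed < '2.6.3')   # False -> takes the else branch, passing window_size_left
print(installed <= '2.6.3')  # True  -> would correctly select the old-API branch
# String comparison is lexicographic, so it also misorders multi-digit components:
print('2.10.0' < '2.6.3')    # True, even though 2.10.0 is the newer release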
To be precise, the parameter window_size_left was first introduced in flash_attn 2.7.0; in 2.6.3 the signature of _flash_attn_forward is still:
def _flash_attn_forward(
    q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
)
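One can verify the signature of the installed build directly (assuming flash_attn is importable as a plain Python function; on builds where it is wrapped as a torch custom op, the printed signature may be less informative):

import inspect
from flash_attn.flash_attn_interface import _flash_attn_forward

# Confirms whether the installed flash_attn accepts window_size
# or the newer window_size_left/window_size_right parameters.
print(inspect.signature(_flash_attn_forward))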
Solution
Correct the branch condition to flash_attn.__version__ <= '2.6.3', or apply an equivalent fix, e.g. comparing parsed versions as sketched below.
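As one possible equivalent fix (a sketch, assuming the packaging library is available; not necessarily the maintainers' preferred solution), compare parsed versions instead of raw strings:

from packaging.version import parse as parse_version
import flash_attn

# Parsed comparison includes 2.6.3 in the old-API branch and also
# orders multi-digit components correctly (e.g. 2.10.0 > 2.6.3).
use_old_api = parse_version(flash_attn.__version__) <= parse_version('2.6.3')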
Could you please open a PR to fix this bug?