
Flash Attention seems broken

Open tanhevg opened this issue 1 year ago • 4 comments

Hello,

Thanks for a very useful and well documented package. Great effort!

I have tried to train the model with flash attention enabled, and ran into the following error:

....
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bmm/home/et517/code/openfold/openfold/model/msa.py", line 284, in forward
    flash_mask=mask,
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bmm/home/et517/code/openfold/openfold/model/primitives.py", line 502, in forward
    o = _flash_attn(q, k, v, flash_mask)
  File "/bmm/home/et517/code/openfold/openfold/model/primitives.py", line 718, in _flash_attn
    softmax_scale = 1., # q has been scaled already
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/flash_attn/flash_attn_interface.py", line 301, in flash_attn_unpadded_kvpacked_func
    return_attn_probs)
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/flash_attn/flash_attn_interface.py", line 98, in forward
    max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/flash_attn/flash_attn_interface.py", line 23, in _flash_attn_forward
    softmax_scale, False, causal, return_softmax, num_splits, generator
RuntimeError: cu_seqlens_k must have shape (batch_size + 1)

The error is raised from the FlashAttention CUDA/C++ code, but I think it is actually caused by incorrect reshaping of kv_mask in primitives.py:_flash_attn. It may be due to changes on their side; according to their GitHub page they have made a number of releases in the last few months. I am using 0.2.8, which is the most recent release at the moment.

I have corrected the mask shape on my branch, which got rid of that particular error (see the sketch below for the kind of shape the varlen interface expects).
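For reference, here is a minimal sketch of what I mean. This is not OpenFold's actual code, and cu_seqlens_from_mask is just a name I made up for illustration, but it shows the metadata the error is complaining about: the unpadded kvpacked FlashAttention call wants cumulative key-sequence lengths of shape (batch_size + 1,) derived from the padding mask, rather than the per-token mask itself.

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_mask(kv_mask: torch.Tensor) -> torch.Tensor:
    """Build the (batch_size + 1,) cumulative-length tensor that the
    unpadded/varlen FlashAttention interface expects for the keys.

    kv_mask: [batch_size, seq_len] 0/1 padding mask (1 = real token).
    Returns an int32 tensor, e.g. per-sequence lengths [3, 5] -> [0, 3, 8].
    """
    lengths = kv_mask.sum(dim=-1, dtype=torch.int32)              # [batch_size]
    cu_seqlens = torch.cumsum(lengths, dim=0, dtype=torch.int32)  # [batch_size]
    return F.pad(cu_seqlens, (1, 0))                              # prepend 0
```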

I have tried running the fixed version, but now the losses are NaN. Any ideas what might be causing this? It runs fine without FlashAttention.

What kind of performance improvement should I expect, anyway? I am running on two servers with four NVIDIA Quadro RTX 6000 GPUs each, and without FlashAttention each training batch takes about a minute, which seems a bit slow.

Alternatively, if you could point to an older FlashAttention version that is known to work, that would be much appreciated.

tanhevg Feb 01 '23 23:02