Flash Attention seems broken
Hello,
Thanks for a very useful and well-documented package. Great effort!
I have tried to train the model with FlashAttention enabled and ran into the following error:
....
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bmm/home/et517/code/openfold/openfold/model/msa.py", line 284, in forward
    flash_mask=mask,
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bmm/home/et517/code/openfold/openfold/model/primitives.py", line 502, in forward
    o = _flash_attn(q, k, v, flash_mask)
  File "/bmm/home/et517/code/openfold/openfold/model/primitives.py", line 718, in _flash_attn
    softmax_scale = 1., # q has been scaled already
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/flash_attn/flash_attn_interface.py", line 301, in flash_attn_unpadded_kvpacked_func
    return_attn_probs)
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/flash_attn/flash_attn_interface.py", line 98, in forward
    max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax
  File "/bmm/home/et517/.conda/envs/openfold/lib/python3.7/site-packages/flash_attn/flash_attn_interface.py", line 23, in _flash_attn_forward
    softmax_scale, False, causal, return_softmax, num_splits, generator
RuntimeError: cu_seqlens_k must have shape (batch_size + 1)
The error comes from the Flash Attention C code, but I think it is caused by incorrect reshaping of kv_mask in primitives.py:_flash_attn. It might have been triggered by changes on their side: according to their GitHub page, they have made a number of releases in the last few months. I am using 0.2.8, which is the most recent version right now.
I have corrected the mask shape on my branch and got rid of that particular error.
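For reference, the shape requirement in the error message means flash-attn's unpadded kernels want cumulative key sequence lengths, not the padded mask itself. Below is a minimal sketch of how such a tensor can be built from a [batch_size, seqlen_k] key/value mask; the helper name is mine for illustration, not necessarily what my branch does:

```python
import torch
import torch.nn.functional as F

def build_cu_seqlens(kv_mask: torch.Tensor) -> torch.Tensor:
    """Sketch only: turn a [batch_size, seqlen_k] 0/1 key mask into the
    [batch_size + 1] int32 cumulative-length tensor that flash-attn's
    unpadded interface (e.g. flash_attn_unpadded_kvpacked_func) expects."""
    seqlens_k = kv_mask.sum(dim=-1, dtype=torch.int32)              # [batch_size]
    cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
    return cu_seqlens_k                                             # [batch_size + 1]
```

In _flash_attn, any leading batch-like dimensions would have to be flattened into the batch dimension first so that the mask really is [batch_size, seqlen_k].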
I have tried running the fixed version, but now the losses are NaN. Any ideas what might be causing this? It seems to run OK without Flash.
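I can't pin down the NaNs yet; for what it's worth, this is roughly how I've been trying to narrow them down (a sketch, not a diagnosis; find_fully_masked_rows is just an illustrative helper):

```python
import torch

def find_fully_masked_rows(kv_mask: torch.Tensor) -> torch.Tensor:
    # Rows with no valid keys are a common source of NaNs after softmax,
    # since every attention logit in such a row is masked out.
    empty = kv_mask.sum(dim=-1) == 0
    if empty.any():
        print(f"{int(empty.sum())} fully-masked rows out of {empty.numel()}")
    return empty

# Localize the first op that produces NaNs in the backward pass (very slow; debugging only).
torch.autograd.set_detect_anomaly(True)
```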
What kind of performance improvement should I be expecting, anyway? I am running on two servers with four Nvidia Quadro RTX 6000 GPUs each, and without FlashAttention each training batch takes about a minute, which seems a bit too slow.
Alternatively, if you could provide an older version of Flash that is proven to work, it would be much appreciated.
I haven't personally gotten FlashAttention training working (I see very similar issues/slow runtimes). It's currently included in the package as an inference-time optimization.
Thanks for the update. To clarify, the slow runtime (30 seconds to 1 minute per training step) is without FlashAttention; I thought enabling FlashAttention would be the way to speed things up. Should I be expecting anything faster on my hardware? Any tips on how to debug slow runtimes?
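In case it helps to see exactly what I mean by "slow", this is roughly how I've been looking at per-step time (a sketch using torch.profiler; train_loader, model and optimizer stand in for the real training loop, and the step counts are arbitrary):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Profile a few training steps to see where the time goes (CPU ops, CUDA
# kernels, data loading). The schedule skips warm-up steps before recording.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3),
    record_shapes=True,
) as prof:
    for step, batch in enumerate(train_loader):  # placeholder DataLoader
        loss = model(batch)                      # placeholder for the real training step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()
        if step >= 5:
            break

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```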
Are you using the "initial_training" setting?
Yes, I am.