ColabDesign
Flash attention
Implements FlashAttention similarly to https://github.com/google-deepmind/alphafold/pull/931
For a 759-residue protein and model_5, this improves runtime by 2.2x on an L4 (37.3 $\rightarrow$ 16.9 seconds, with minibatching of 256 for non-flash attention to avoid OOM).
Here's a colab link showing the runtime improvement and, by visual inspection, no significant change in prediction output. I didn't want to rerun all the input prep, so I've used a colab with the AlphaFold input preparation and applied the ColabDesign fixes.
Notes
Key variations from a reference flash attention kernel are:
- Attention logit biasing supported
- Gating supported
- Some heads have only 8 channels; they're padded up to 16 within the kernel (a requirement of pl.dot). We still see a performance improvement relative to non-flash attention, and overall AlphaFold2 memory requirements stay linear
- Broadcasted masks in the batch, q and head dimensions are supported (they're often size 1 and implicitly broadcast in AlphaFold2 einsums); see the sketch after this list
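To make the list above concrete, here is a plain-JAX sketch of the attention variant the kernel has to match (logit bias, sigmoid gating, broadcastable masks, and zero-padding of small head channels). The names `reference_attention`, `pad_head_channels` and `gate_values`, and the exact shape conventions, are my assumptions for illustration, not the PR's code; in AlphaFold2 the gate values come from a learned linear projection of the input.

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v, bias, mask, gate_values):
    # q, k, v:      [batch, heads, q_len/k_len, channels]
    # bias:         broadcastable to [batch, heads, q_len, k_len]
    # mask:         broadcastable to [batch, heads, q_len, k_len], 1.0 = attend
    # gate_values:  pre-sigmoid gates, same shape as the attention output
    logits = jnp.einsum('bhqc,bhkc->bhqk', q * q.shape[-1] ** -0.5, k)
    logits = logits + bias + (mask - 1.0) * 1e9   # masked positions -> large negative
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum('bhqk,bhkc->bhqc', weights, v)
    return out * jax.nn.sigmoid(gate_values)      # sigmoid gating of the output

def pad_head_channels(x, min_channels=16):
    # Zero-padding the channel dimension up to pl.dot's minimum tile size keeps
    # q.k^T unchanged; the extra output channels are zeros and can be sliced off.
    pad = max(0, min_channels - x.shape[-1])
    return jnp.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, pad)])
```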
There are guards against the kernel being called for sequence lengths shorter than the specified q and k block sizes; in that case it falls back to the reference kernel.
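A minimal sketch of that dispatch guard, reusing the illustrative `reference_attention` above and a hypothetical `flash_attention` entry point (names, argument order and default block sizes are assumptions, not the PR's API):

```python
def attend(q, k, v, bias, mask, gate_values, *, block_q=64, block_k=64,
           use_flash_attention=True):
    q_len, k_len = q.shape[-2], k.shape[-2]
    # Fall back to the reference kernel when the sequence is shorter than the
    # Pallas block sizes (the flash kernel assumes at least one full block).
    if not use_flash_attention or q_len < block_q or k_len < block_k:
        return reference_attention(q, k, v, bias, mask, gate_values)
    return flash_attention(q, k, v, bias, mask, gate_values,
                           block_q=block_q, block_k=block_k)
```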
Comments
- I think the runtime improvement is also benefiting from the triangular fusion you previously implemented; on an A100 I saw that, with flash attention and bfloat16, it starts to be significant.
- I haven't done correctness/speed checks with multimer models or models using templates. If you have a test suite to do that, that'd be wonderful.
- When you said 'fused attention', you meant shifting the mask to a bias so XLA lowers it to a fused kernel, right? I've moved that mask $\rightarrow$ bias conversion into the Attention module itself and kept it in the reference_kernel (so the reference_kernel now differs from the one in google-deepmind/alphafold#931). So with `use_flash_attention=False` I haven't changed behaviour: here's a colab showing the same 37.3 s runtime as the main branch. See the sketch after this list for the mask $\rightarrow$ bias conversion.
- fix for use_dgram, which seemed to access the wrong config keys
- fix for models not containing pae head
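The mask $\rightarrow$ bias conversion mentioned above amounts to turning a 0/1 mask into an additive logit offset; a hedged sketch (the helper name and shape convention are illustrative, not the module's exact code):

```python
import jax.numpy as jnp

def mask_to_bias(mask, dtype=jnp.float32):
    # mask: broadcastable to [batch, heads, q_len, k_len]; 1 = attend, 0 = masked.
    # Expressing the mask as an additive bias lets XLA (or the flash kernel)
    # treat masking and pair biasing uniformly as one logit offset.
    return 1e9 * (mask.astype(dtype) - 1.0)

# Inside the Attention module this would just be added to any existing pair
# bias before the softmax, e.g. bias = pair_bias + mask_to_bias(mask).
```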
@sokrypton I think this is ready for merging.
It's still strictly opt-in (as Pallas with Triton is only available for Ampere-architecture GPUs and up).
You could improve performance a bit more by tuning block sizes and the number of warps in an input-shape-dependent manner, and similarly the `subbatch_size` global config setting could default to a heuristic that selects subbatch sizes from estimated memory usage; a rough sketch of both ideas follows.
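The thresholds, helper names, and memory model below are made up for illustration only; actual values would need benchmarking on the target GPUs.

```python
def choose_block_sizes(q_len, k_len):
    # Larger blocks amortise more work per program instance on long sequences;
    # smaller blocks (and fewer warps) tend to do better on short ones.
    if max(q_len, k_len) >= 1024:
        return dict(block_q=128, block_k=128, num_warps=8)
    return dict(block_q=64, block_k=64, num_warps=4)

def choose_subbatch_size(num_items, item_bytes, free_memory_bytes):
    # Pick the largest subbatch that fits an estimate of free memory, instead
    # of relying on the fixed global `subbatch_size` setting.
    return max(1, min(num_items, free_memory_bytes // item_bytes))
```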