ColabDesign

Flash attention

Open • oliverdutton opened this issue 10 months ago • 1 comment

Implements FlashAttention, following the approach in https://github.com/google-deepmind/alphafold/pull/931
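The core idea of FlashAttention is to stream over blocks of keys and values while keeping a running row-wise max and softmax denominator, so the full q×k logits matrix is never materialized. The following is a minimal NumPy sketch of that online-softmax recurrence for a single head with no bias or mask. It is not the Pallas kernel from this PR; `tiled_attention` and `block_k` are illustrative names, and q is assumed to be pre-scaled by 1/sqrt(d).

```python
import numpy as np

def reference_attention(q, k, v):
    """Standard attention: materializes the full [Lq, Lk] logits matrix."""
    logits = q @ k.T
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

def tiled_attention(q, k, v, block_k=16):
    """Online-softmax attention over key/value blocks of size block_k (hypothetical)."""
    lq, d_v = q.shape[0], v.shape[1]
    m = np.full((lq, 1), -np.inf)   # running row-wise max of the logits
    l = np.zeros((lq, 1))           # running softmax denominator
    acc = np.zeros((lq, d_v))       # running unnormalized output
    for start in range(0, k.shape[0], block_k):
        kb = k[start:start + block_k]
        vb = v[start:start + block_k]
        s = q @ kb.T                                  # logits for this key block only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        scale = np.exp(m - m_new)                     # rescale earlier partial sums
        p = np.exp(s - m_new)
        l = l * scale + p.sum(axis=-1, keepdims=True)
        acc = acc * scale + p @ vb
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q = rng.standard_normal((40, 8))
k = rng.standard_normal((50, 8))
v = rng.standard_normal((50, 8))
print(np.allclose(tiled_attention(q, k, v), reference_attention(q, k, v)))  # True
```

Peak extra memory in the tiled version scales with `Lq * block_k` rather than `Lq * Lk`, which is why flash attention keeps memory linear in sequence length.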

For a 759-residue protein with model_5, this reduces runtime 2.2x on an L4 (37.3 $\rightarrow$ 16.9 seconds; the non-flash baseline uses minibatching of 256 to avoid OOM).

Here's a colab link showing the runtime improvement and, by visual inspection, no significant change in the predicted output. To avoid rerunning all the input prep, I used a colab with AlphaFold input preparation and adapted it for ColabDesign.

Notes

Key differences from a reference flash attention kernel:

  • Attention logit biasing supported
  • Gating supported
  • Some heads have only 8 channels; these are padded up to 16 inside the kernel, as pl.dot requires. Performance still improves relative to non-flash attention, and overall AlphaFold2 memory requirements stay linear
  • Broadcast masks in the batch, q, and head dimensions are supported (these dimensions are often size 1 and implicitly broadcast in AlphaFold2 einsums)
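The features listed above (logit bias, gating, 8→16 channel padding, broadcast masks) can be illustrated with a plain NumPy reference. This is a hypothetical sketch, not the PR's kernel: the `[B, H, L, D]` layout, the function names, and the `-1e9` fill value are assumptions. It shows why zero-padding the channel axis is safe: the extra zero channels add nothing to q·kᵀ, and the padded output channels are sliced off afterwards.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def pad_head_dim(x, multiple=16):
    """Zero-pad the last (channel) axis up to a multiple of `multiple`."""
    d = x.shape[-1]
    target = -(-d // multiple) * multiple
    return np.pad(x, [(0, 0)] * (x.ndim - 1) + [(0, target - d)])

def gated_biased_attention(q, k, v, bias, mask, gate):
    """q, gate: [B, H, Lq, D]; k, v: [B, H, Lk, D].
    bias and bool mask only need to broadcast to [B, H, Lq, Lk] (size-1 axes allowed)."""
    logits = np.einsum("bhqd,bhkd->bhqk", q, k) + bias
    logits = np.where(mask, logits, -1e9)       # masked keys get ~zero weight
    logits -= logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=-1, keepdims=True)
    out = np.einsum("bhqk,bhkd->bhqd", w, v)
    return out * sigmoid(gate)                  # per-channel sigmoid gating

rng = np.random.default_rng(0)
B, H, Lq, Lk, D = 1, 4, 6, 7, 8
q, gate = rng.standard_normal((2, B, H, Lq, D))
k, v = rng.standard_normal((2, B, H, Lk, D))
bias = rng.standard_normal((B, H, Lq, Lk))
mask = np.ones((B, 1, 1, Lk), dtype=bool)       # size 1 in head and q dimensions
mask[..., -2:] = False                          # mask out the last two keys
out = gated_biased_attention(q, k, v, bias, mask, gate)
out_padded = gated_biased_attention(
    pad_head_dim(q), pad_head_dim(k), pad_head_dim(v), bias, mask, pad_head_dim(gate)
)[..., :D]                                      # drop the zero-padded channels
print(np.allclose(out, out_padded))             # True
```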

There are guards that fall back to the reference kernel when the sequence length is shorter than the block sizes specified for q and k.
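A guard of this kind amounts to a shape check before choosing a kernel. The sketch below is hypothetical: `attention_dispatch`, the `flash_fn`/`reference_fn` callables, and the default block sizes of 64 are illustrative, not the PR's actual names or values.

```python
import numpy as np

def attention_dispatch(q, k, v, flash_fn, reference_fn, block_q=64, block_k=64):
    """Call the flash kernel only if both sequence lengths cover at least one full block."""
    lq, lk = q.shape[-2], k.shape[-2]
    if lq < block_q or lk < block_k:
        return reference_fn(q, k, v)   # too short for the tiled kernel: use the reference path
    return flash_fn(q, k, v)
```

Keeping the fallback inside the dispatcher means callers never need to know which path ran.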

Comments

  • I think the runtime improvement also benefits from the triangular fusion you implemented previously: on an A100 with flash attention and bfloat16, I saw that contribution start to become significant.
  • I haven't run correctness or speed checks with multimer models or models that use templates. If you have a test suite for that, that'd be wonderful.
  • When you said 'fused attention', did you mean converting the mask to a bias so XLA lowers it to a fused kernel? I've moved that mask $\rightarrow$ bias conversion into the Attention module itself and kept it in reference_kernel (so reference_kernel now differs from the one in google-deepmind/alphafold#931). With use_flash_attention=False, behaviour is unchanged: here's a colab showing the same 37.3s runtime as the main branch.
  • Fix for use_dgram, which seemed to access the wrong config keys
  • Fix for models without a PAE head
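The mask → bias conversion mentioned above replaces a select with an addition, which XLA can fuse with the logits computation. A minimal NumPy sketch, assuming a {0,1} mask and the large-negative-constant convention common in AlphaFold-style code (`mask_to_bias` and the `1e9` constant are illustrative, not taken from this PR):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def mask_to_bias(mask, big=1e9):
    """0 where mask is 1 (keep), -big where mask is 0 (drop)."""
    return big * (mask.astype(np.float32) - 1.0)

rng = np.random.default_rng(0)
logits = rng.standard_normal((3, 5))
mask = np.array([[1, 1, 1, 0, 0]])   # size-1 query axis, broadcast over all rows
w_bias = softmax(logits + mask_to_bias(mask))
w_where = softmax(np.where(mask.astype(bool), logits, -np.inf))
print(np.allclose(w_bias, w_where))  # True
```

Because `exp(-1e9)` underflows to exactly zero, masked keys receive zero weight, matching the explicit masking path.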

oliverdutton • Apr 21 '24 11:04

@sokrypton I think this is ready for merging.

It's still strictly opt-in, since Pallas with Triton is only available on Ampere-architecture GPUs and newer.

Performance could be improved further by tuning the block sizes and number of warps based on input shape. Similarly, the `subbatch_size` global config setting could be replaced by a default heuristic that selects subbatch sizes based on memory usage.
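A memory-based subbatch heuristic like the one suggested could look roughly like this. It is a hypothetical sketch, not part of the PR: `pick_subbatch_size` and its parameters are illustrative names, and it assumes the dominant cost in non-flash attention is a float32 `[subbatch, heads, L, L]` logits tensor.

```python
def pick_subbatch_size(num_rows, num_heads, seq_len, budget_bytes, bytes_per_elem=4):
    """Largest power-of-two subbatch whose [subbatch, H, L, L] logits fit in budget_bytes."""
    per_row = num_heads * seq_len * seq_len * bytes_per_elem  # logits bytes for one row
    max_rows = max(1, budget_bytes // per_row)
    size = 1
    while size * 2 <= min(max_rows, num_rows):
        size *= 2
    return size

# 759 residues, 4 heads, 1 GiB budget: ~9.2 MB per row, so at most 116 rows fit.
print(pick_subbatch_size(759, 4, 759, 2**30))  # 64
```

Rounding down to a power of two is a design choice that keeps chunk shapes regular; a real heuristic would also need to account for the other live buffers.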

oliverdutton • May 05 '24 21:05