Flash attention
Flash attention is implemented using Pallas to reduce runtime and memory usage. It is added on an opt-in basis in the global config.
For a 759 residue protein and model_5, this drops peak memory consumption to 5 GB without minibatching and reduces runtime 2.3x on an A100 (15.2 $\rightarrow$ 6.5 seconds, with minibatching of 256 for non-flash attention to avoid OOM).
Here's a Colab link showing the runtime improvement and no significant change in prediction output by visual inspection.
When combined with https://github.com/google-deepmind/alphafold/pull/930 (bfloat16 support for monomer models), peak memory drops to only 2.7 GB and runtime to 5.6 seconds (a 2.7x speedup relative to non-flash, float32).
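For anyone wanting to try it, opting in should look roughly like the sketch below; the use_flash_attention flag name is an assumption for illustration, not necessarily the exact field this PR adds to the global config.

from alphafold.model import config

cfg = config.model_config('model_5')
# Hypothetical opt-in flag; the actual global_config field name may differ.
cfg.model.global_config.use_flash_attention = True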
Notes:
Key variations from a reference flash attention kernel are:
- Attention logit biasing supported
- Gating supported
- Some heads have only 8 channels; they're padded up to 16 within the kernel (this is a requirement of pl.dot; we still see a performance improvement relative to non-flash attention, and it keeps overall AlphaFold2 memory requirements linear)
- Broadcasted masks in the batch, q and head dimensions supported (they're often size 1 and implicitly broadcasted in AlphaFold2 einsums); see the sketch after this list
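For context, here is a minimal plain-JAX sketch (not the Pallas kernel itself) of the AlphaFold-style attention the kernel has to reproduce, with logit bias, gating, and a broadcastable bias/mask term; shapes and names are illustrative.

import jax
import jax.numpy as jnp

def reference_attention(q, k, v, bias, gate):
  # q, k, v: [batch, seq, heads, channels]; bias broadcasts to [batch, heads, q, k].
  logits = jnp.einsum('bqhc,bkhc->bhqk', q / jnp.sqrt(q.shape[-1]), k)
  weights = jax.nn.softmax(logits + bias, axis=-1)
  out = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
  return out * jax.nn.sigmoid(gate)  # gating, as in AlphaFold's Attention module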
There are guards against the kernel being called for sequence lengths shorter than the block sizes specified for q and k; in that case it exits to the reference kernel.
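A hedged sketch of that kind of guard, reusing the reference_attention sketch above and a hypothetical flash_attention entry point (names and defaults are illustrative, not the PR's actual code):

def attention(q, k, v, bias, gate, *, block_q=128, block_k=128):
  # Fall back to the reference path when a sequence is shorter than a block.
  if q.shape[-3] < block_q or k.shape[-3] < block_k:
    return reference_attention(q, k, v, bias, gate)
  return flash_attention(q, k, v, bias, gate, block_q=block_q, block_k=block_k)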
I haven't done correctness checks with multimer models; I would do so if there were a positive response to this pull request. I'm not yet certain about the numerical stability of the implementation with bfloat16.
(I can switch out the exp and log for exp2 and log2 for a small reduction in runtime; this leads to slightly different predictions, but I believe testing would show equivalent error in structure prediction.)
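For reference, the exp2 variant relies on exp(x) = exp2(x * log2(e)), and exp2 is typically cheaper than exp on GPU. A small isolated sketch (illustrative, not the PR's kernel code):

import math
import jax.numpy as jnp

LOG2_E = math.log2(math.e)

def softmax_numerator_exp2(logits, row_max):
  # Equivalent to jnp.exp(logits - row_max), up to floating-point rounding.
  return jnp.exp2((logits - row_max) * LOG2_E)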
Hi @oliverdutton! Really cool contribution. Mind if we try adding it to ColabFold? We already have fused attention and bfloat16 integrated into the monomer model. It will be interesting to try flash attention as well.
@sokrypton Of course, I've made a pull request in ColabDesign with it (https://github.com/sokrypton/ColabDesign/pull/173)
Before https://github.com/google-deepmind/alphafold/pull/931/commits/d4516d83aaf65aee2e2c90ca85b86acacd464c0f, I find transient NaN behaviour on shapes which don't evenly divide the block size (i.e. out-of-bounds loading).
Gist to reproduce the problem:
import jax
from jax import numpy as jnp
from alphafold.model import model

key = jax.random.PRNGKey(42)
nrepeats = 100

# jit once, outside the loops, so each shape is only compiled once
f = jax.jit(model.modules.Attention.flash_kernel, static_argnames=(
    'return_residual', 'block_q', 'block_k', 'num_warps', 'num_stages',
    'grid', 'interpret', 'debug'))

for nres in range(128, 256):
  print(nres)
  for i in range(nrepeats):
    # leading axis of 3 unpacks into q, k, v: each [batch, seq, heads, channels]
    q, k, v = jax.random.uniform(key, (3, 1024, nres, 8, 32))
    assert jnp.isfinite(f(q, k, v)).all(), f"Failed with {nres} on run {i}"
After https://github.com/google-deepmind/alphafold/pull/931/commits/d4516d83aaf65aee2e2c90ca85b86acacd464c0f, the transient NaN behaviour disappears, so I hope this will now always be NaN-free.
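The underlying issue is that the last key block can read past the end of the sequence; conceptually the fix presumably amounts to masking those positions (or the loads themselves) before the softmax. An illustrative sketch, not the PR's exact kernel code:

import jax.numpy as jnp

def mask_oob_logits(logits, start_k, block_k, seq_len):
  # Mask key positions past the true sequence length so out-of-bounds
  # garbage can never contribute to the softmax.
  span_k = start_k + jnp.arange(block_k)
  return jnp.where(span_k[None, :] < seq_len, logits, -1e30)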
Thank you very much, this improvement is very useful. I am using an RTX 3090 to predict a 3645aa heterotetramer. With this improvement, the prediction time of a single model has decreased from 59,000 seconds to 43,000 seconds (also running beyond the GPU memory limit).