Flash Attention
I think flash attention support in Equinox should be treated as a critical issue, considering it is already natively built into PyTorch.
XLA is supposed to (in theory) do some of this fusion, and possibly tiling as well, but Pallas offers device-specific (GPU/TPU) kernels under jax.experimental that may perform much better:
- TPU: https://github.com/google/jax/blob/118ca21b5b9de069ff2abc81bc16c1ccbcf236ab/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
- GPU: https://github.com/google/jax/blob/118ca21b5b9de069ff2abc81bc16c1ccbcf236ab/jax/experimental/pallas/ops/attention.py#L166
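For reference, the idea these kernels share is computing softmax(QK^T)V one key/value block at a time, keeping a running row max and a running row sum (the "online softmax" trick), so the full [seq, seq] score matrix never has to exist at once. The pure-JAX sketch below only illustrates that math under simplifying assumptions: a single head, tiling over keys only, and a key length divisible by the block size. `naive_attention` and `blockwise_attention` are hypothetical names. It is not how the Pallas kernels are written and shouldn't be expected to match their speed, since their benefit comes from keeping the tiles in on-chip memory.

```python
import jax
import jax.numpy as jnp


def naive_attention(q, k, v):
    # q: [T_q, d], k/v: [T_k, d]; materializes the full [T_q, T_k] score matrix.
    s = (q @ k.T) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(s, axis=-1) @ v


def blockwise_attention(q, k, v, block_size=64):
    # Same result as naive_attention, but each step only holds a
    # [T_q, block_size] slice of the scores (online softmax).
    t_k, d = k.shape
    assert t_k % block_size == 0, "sketch assumes T_k divisible by block_size"
    k_blocks = k.reshape(t_k // block_size, block_size, d)
    v_blocks = v.reshape(t_k // block_size, block_size, v.shape[-1])
    scale = 1.0 / jnp.sqrt(d)

    def step(carry, kv):
        acc, row_max, row_sum = carry
        k_blk, v_blk = kv
        s = (q @ k_blk.T) * scale                      # [T_q, block_size]
        new_max = jnp.maximum(row_max, s.max(axis=-1))
        correction = jnp.exp(row_max - new_max)        # rescale earlier statistics
        p = jnp.exp(s - new_max[:, None])
        acc = acc * correction[:, None] + p @ v_blk
        row_sum = row_sum * correction + p.sum(axis=-1)
        return (acc, new_max, row_sum), None

    init = (
        jnp.zeros((q.shape[0], v.shape[-1]), q.dtype),  # unnormalized output
        jnp.full((q.shape[0],), -jnp.inf, q.dtype),     # running row max
        jnp.zeros((q.shape[0],), q.dtype),              # running row sum
    )
    (acc, _, row_sum), _ = jax.lax.scan(step, init, (k_blocks, v_blocks))
    return acc / row_sum[:, None]


q, k, v = jax.random.normal(jax.random.PRNGKey(0), (3, 256, 64))
out = blockwise_attention(q, k, v)  # agrees with naive_attention(q, k, v) up to float32 rounding
```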
Perhaps it would be better to integrate these kernels into Equinox. As for handling the different devices, I'm not sure there's any good solution except to get the device type at compile time and use a match statement to route the attention calculation to the correct kernel.
I'll also investigate and check if it actually offers any tangible benefit in throughput and report my results here.
Sounds good! This has been on my wishlist a long time, but I've never sat down to sort it out.
FWIW I believe XLA already does some level of algebraic pattern-matching here but I'm not sure of the details.
One difficulty might be the per-device dispatch; I think that might require a custom primitive? (I also haven't looked up what Flax etc. do here.)
Average naive MHSA time: 0.27780s
Average Pallas flash_attention time: 0.08993s
That's a ~3x speedup on a TPU v3-8 (hyperparameters are in the script), which seems consistent with what other JAX people have reported to me.
I used a naive MHSA implementation (courtesy of GPT-4), since from a quick skim Equinox doesn't do anything special.
(BTW, why https://github.com/patrick-kidger/equinox/blob/2bbedf8bd23cd5fe8eb5a83793c19c5a77fe099a/equinox/nn/_attention.py#L25 when you can just do query @ key.T?)
Here's the benchmarking script. LMK if there's some problem here, since I threw it together pretty quickly.
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention
from typing import Tuple
import jax
import math
import time
import jax.numpy as jnp

@jax.jit
def scaled_dot_product_attention(Q, K, V):
    """
    Calculates the attention weights and returns the output after applying these weights to V.
    Args:
    - Q, K, V are the query, key, and value tensors respectively, each [n_head, seq_len, head_dim].
    """
    # Compute the dot product, scaled by the square root of the depth of K
    matmul_qk = jnp.einsum('tih,tjh->tij', Q, K)  # [n_head, seq_len, seq_len]
    dim_k = K.shape[-1]
    scaled_attention_logits = matmul_qk / jnp.sqrt(dim_k)
    # Apply softmax to get the weights on the values
    weights = jax.nn.softmax(scaled_attention_logits, axis=-1)  # [n_head, seq_len, seq_len]
    # Apply the weights to the values
    output = jnp.einsum('tij,tjh->tih', weights, V)  # [n_head, seq_len, head_dim]
    return output

@jax.jit
def vanilla_attn(Q, K, V):
    """
    Perform multi-head self-attention on the inputs Q, K, V.
    Args:
    - Q, K, V: Input tensors with dimensions [n_head, seq_len, head_dim].
    """
    return scaled_dot_product_attention(Q, K, V)

def generate_data(batch_size: int, seq_len: int, embed_dim: int, num_heads: int) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Utility function to generate dummy data for attention mechanism."""
    scale = 1.0 / math.sqrt(embed_dim)
    shape = (batch_size, num_heads, seq_len, embed_dim)
    q = jax.random.normal(jax.random.PRNGKey(0), shape) * scale
    k = jax.random.normal(jax.random.PRNGKey(1), shape) * scale
    v = jax.random.normal(jax.random.PRNGKey(2), shape) * scale
    return q, k, v

def test_performance():
    n: int = 20
    batch_size, seq_len, embed_dim, num_heads = 256, 512, 128, 8
    # Testing the naive MHSA (vmapped over the batch axis)
    start_time = time.time()
    for _ in range(n):
        q, k, v = generate_data(batch_size, seq_len, embed_dim, num_heads)
        _ = jax.vmap(vanilla_attn)(q, k, v).block_until_ready()
    equinox_time = time.time() - start_time
    # Testing Pallas's flash_attention
    start_time = time.time()
    for _ in range(n):
        q, k, v = generate_data(batch_size, seq_len, embed_dim, num_heads)
        _ = jax.jit(flash_attention)(q, k, v).block_until_ready()
    pallas_time = time.time() - start_time
    print(f"Average naive MHSA time: {equinox_time / n:.5f}s")
    print(f"Average Pallas flash_attention time: {pallas_time / n:.5f}s")

if __name__ == "__main__":
    test_performance()
Oh, that is a pretty awesome speedup. Sounds like we should find a way to get this in :)
As for why the einsum, just that I find it more readable!
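A quick check that an einsum of the form `"sd,Sd->sS"` and `query @ key.T` compute the same logits (arbitrary shapes; this is not Equinox's code). The einsum spelling mostly pays off once extra axes appear, e.g. a leading heads axis, where the matmul version needs an explicit swap of the last two axes.

```python
import jax
import jax.numpy as jnp

query = jax.random.normal(jax.random.PRNGKey(0), (10, 16))  # [q_seq, qk_size]
key = jax.random.normal(jax.random.PRNGKey(1), (12, 16))    # [kv_seq, qk_size]

# Einsum form: the axis roles are spelled out in the subscripts.
logits_einsum = jnp.einsum("sd,Sd->sS", query, key)
# Matmul form.
logits_matmul = query @ key.T

# With a leading heads axis, .T would reverse *all* axes, so swap the last two instead.
q_h = jax.random.normal(jax.random.PRNGKey(2), (4, 10, 16))  # [heads, q_seq, qk_size]
k_h = jax.random.normal(jax.random.PRNGKey(3), (4, 12, 16))  # [heads, kv_seq, qk_size]
batched_einsum = jnp.einsum("hsd,hSd->hsS", q_h, k_h)
batched_matmul = q_h @ jnp.swapaxes(k_h, -1, -2)
```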
On the benchmarking script: I think this could be tweaked slightly to move the vmap inside of jit, and to avoid reapplying the jax.jit decorator multiple times inside the loop. Both of those are fairly minor overheads though. A slightly more important overhead might be the compile time: a JIT'd function should be called once before the loop so that the compilation isn't included in the measurements.
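A small timing helper along the lines suggested above: compile once before timing, keep vmap inside jit, and wait on the result with `jax.block_until_ready`. `benchmark` and `batched_attention` are hypothetical names. `time.perf_counter` is used instead of `time.time` since it is the clock intended for measuring short intervals.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_iters=20):
    """Average wall-clock seconds per call of an already-jitted `fn`."""
    jax.block_until_ready(fn(*args))      # warmup: triggers compilation, not timed
    start = time.perf_counter()
    for _ in range(n_iters):
        jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
    return (time.perf_counter() - start) / n_iters


@jax.jit
def batched_attention(q, k, v):
    # vmap over the batch axis *inside* jit, so it is compiled only once.
    def single(q, k, v):  # [heads, seq_len, head_dim]
        logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
        return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(logits, axis=-1), v)

    return jax.vmap(single)(q, k, v)


q = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 128, 64))
print(f"{benchmark(batched_attention, q, q, q, n_iters=5):.5f}s per call")
```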
I made a few optimization changes and have been trying to figure this out for quite a while now,
but the times I'm getting are very close. I asked someone else to test it on different hardware (thinking this was a TPU thing), but they get similar times on an RTX 3090.
Average Naive MHSA time: 0.00341s
Average Pallas flash_attention time: 0.00429s
This does make a few changes: vmap is inside of JIT, I'm generating unique data per loop (pregenerated beforehand), I ensured everything is JIT-ed, and the inputs are block_until_ready()-ed.
Not sure whether XLA is somehow pattern-matching on the fact that I'm doing attention... or maybe I'm just doing something stupid here?
Updated script
import math
import time
from typing import Tuple

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

@jax.jit
def scaled_dot_product_attention(Q, K, V):
    """
    Calculates the attention weights and returns the output after applying these weights to V.
    Args:
    - Q, K, V are the query, key, and value tensors respectively, each [n_head, seq_len, head_dim].
    """
    # Compute the dot product, scaled by the square root of the depth of K
    matmul_qk = jnp.einsum('tih,tjh->tij', Q, K)  # [n_head, seq_len, seq_len]
    dim_k = K.shape[-1]
    scaled_attention_logits = matmul_qk / jnp.sqrt(dim_k)
    # Apply softmax to get the weights on the values
    weights = jax.nn.softmax(scaled_attention_logits, axis=-1)  # [n_head, seq_len, seq_len]
    # Apply the weights to the values
    output = jnp.einsum('tij,tjh->tih', weights, V)  # [n_head, seq_len, head_dim]
    return output

@jax.jit
def vanilla_attn(Q, K, V):
    """
    Perform multi-head self-attention on the inputs Q, K, V.
    Args:
    - Q, K, V: Input tensors with dimensions [batch, n_head, seq_len, head_dim].
    """
    return jax.vmap(scaled_dot_product_attention)(Q, K, V)

def generate_data(seed: int, batch_size: int, seq_len: int, embed_dim: int, num_heads: int) -> Tuple[jax.Array, jax.Array, jax.Array]:
    """Utility function to generate dummy data for attention mechanism."""
    scale = 1.0 / math.sqrt(embed_dim)
    shape = (batch_size, num_heads, seq_len, embed_dim)
    # Split one key per sample: PRNGKey(seed + 1) would otherwise collide with the next sample's PRNGKey(seed).
    kq, kk, kv = jax.random.split(jax.random.PRNGKey(seed), 3)
    q = jax.random.normal(kq, shape) * scale
    k = jax.random.normal(kk, shape) * scale
    v = jax.random.normal(kv, shape) * scale
    return q, k, v

@jax.jit
def flash_attn(Q, K, V):
    # sm_scale defaults to 1.0 in the Pallas kernel; use 1/sqrt(head_dim) to match the naive version.
    return flash_attention(Q, K, V, sm_scale=1.0 / math.sqrt(Q.shape[-1]))

def test_performance():
    n: int = 16
    batch_size, seq_len, embed_dim, num_heads = 32, 512, 128, 8
    data = [generate_data(i, batch_size, seq_len, embed_dim, num_heads) for i in range(n)]
    # Warmup: compile both functions before timing
    _ = (
        vanilla_attn(*data[0]).block_until_ready(),
        flash_attn(*data[0]).block_until_ready(),
    )
    # Testing the naive MHSA
    start_time = time.time()
    for i in range(n):
        q, k, v = data[i]
        q, k, v = q.block_until_ready(), k.block_until_ready(), v.block_until_ready()
        _ = vanilla_attn(q, k, v).block_until_ready()
    equinox_time = time.time() - start_time
    # Testing Pallas's flash_attention
    start_time = time.time()
    for i in range(n):
        q, k, v = data[i]
        q, k, v = q.block_until_ready(), k.block_until_ready(), v.block_until_ready()
        _ = flash_attn(q, k, v).block_until_ready()
    pallas_time = time.time() - start_time
    print(f"Average Naive MHSA time: {equinox_time / n:.5f}s")
    print(f"Average Pallas flash_attention time: {pallas_time / n:.5f}s")

if __name__ == "__main__":
    test_performance()
Oh nice! I think your benchmark script looks correct to me. Indeed I suspect XLA is pattern-matching on attention, then.
Hmm... maybe it's worth testing with eqx.nn.MultiheadAttention to verify it's definitely using the flash kernel, then.
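One rough way to check this without a profiler is to search the compiled HLO for a buffer with the [heads, seq, seq] shape of the attention scores: if it's there, the scores are being materialized and no flash-style kernel is in use. A miss doesn't prove the opposite. The helper below is a heuristic sketch under stated assumptions (hypothetical name `materializes_scores`, float32 inputs in [heads, seq_len, head_dim] layout, seq_len != head_dim); other models or layouts will need a different shape string.

```python
import jax
import jax.numpy as jnp


def materializes_scores(fn, q, k, v):
    """Heuristic: does the optimized HLO of fn contain an f32[heads, seq, seq] buffer?"""
    heads, seq, _ = q.shape
    hlo = jax.jit(fn).lower(q, k, v).compile().as_text()
    return f"f32[{heads},{seq},{seq}]" in hlo


def naive(q, k, v):
    # [heads, seq, head_dim] -> scores [heads, seq, seq] -> output [heads, seq, head_dim]
    logits = jnp.einsum("hqd,hkd->hqk", q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum("hqk,hkd->hqd", jax.nn.softmax(logits, axis=-1), v)


q = jax.random.normal(jax.random.PRNGKey(0), (2, 64, 16))
print(materializes_scores(naive, q, q, q))
```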
TL;DR: you need to use the Pallas splash attention kernel to save memory and boost speed on TPUs (no other method worked for me).
I've looked into this topic heavily (on TPUs >= v4) while training LLMs. There are mainly 4 ways to implement attention in JAX + Equinox:
- jax's attention (https://docs.jax.dev/en/latest/_autosummary/jax.nn.dot_product_attention.html)
- equinox's attention (https://docs.kidger.site/equinox/api/nn/attention/)
- pallas flash attention kernel (https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py)
- pallas splash attention kernel (https://github.com/jax-ml/jax/tree/main/jax/experimental/pallas/ops/tpu/splash_attention)
[speed] 4 (fastest) > 1 > 2 > 3 (slowest)
[memory usage] 1 = 2 > 4, i.e. splash uses the least (not sure about 3, I ditched it since it was so slow)
So the comments above about XLA pattern-matching on attention to use flash attention are incorrect. I have looked at the profiles, and neither Equinox's nor JAX's attention implementation is automatically compiled into flash attention (there are big blocks corresponding to QK^T in the HLO, so no memory is saved). You MUST use a custom Pallas kernel to get the memory and speed boosts.
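To put a number on "big blocks": a naive implementation has to hold at least one [batch, heads, seq, seq] score tensor. With the hyperparameters from the first benchmark script in this thread (batch 256, 8 heads, seq_len 512) and float32, that is 2 GiB, and it grows quadratically with sequence length. `score_bytes` is just an illustrative helper.

```python
def score_bytes(batch_size: int, num_heads: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one materialized [batch, heads, seq, seq] attention-score tensor."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem


GIB = 2**30
print(score_bytes(256, 8, 512) / GIB)   # 2.0
print(score_bytes(256, 8, 4096) / GIB)  # 128.0 (8x the sequence length -> 64x the memory)
```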
Speaking of kernels, I think the Pallas flash attention kernel is broken on TPUs (or at least that's been my experience; this also matches the benchmark run in the comment above).
I suspect Googlers also use splash attention instead of it, since MaxText's flash attention implementation is written with splash attention (so I guess splash attention == flash attention, but written to fit TPUs better?).
(https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/layers/attentions.py. see tpu_flash_attention)
I recommend using jax.nn.dot_product_attention if you do not want to deal with kernels, but if you need optimal performance (or are running out of memory), I highly recommend looking into the Pallas implementation of splash attention.
I was surprised this topic was not discussed thoroughly anywhere (I looked through the issues of JAX, Equinox, etc.). P.S. Thank you for the awesome Equinox library :)