
Thoughts on using `flexattention` for AttentionWithPairBias

Open · shenoynikhil opened this issue 9 months ago · 1 comment

I currently see a custom implementation of AttentionWithPairBias that biases the attention scores with the pair-wise conditioning tensors. FlexAttention offers a neat way to get significant speedups for this pattern (see link). I tested a simple example and see roughly an order-of-magnitude speedup. Curious whether you guys have tried this and have any opinion on it. Additionally, the speedup requires compiling the graph, so a second question is how often this compilation would be needed.
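The pair-bias computation described here maps onto FlexAttention's `score_mod` hook. The hook receives each scaled score together with its `(b, h, q_idx, kv_idx)` position, and its return value feeds the softmax, so returning `score + bias[b, h, q_idx, kv_idx]` gives biased attention. The sketch below spells out those semantics as a naive, unfused CPU reference. `reference_flex_attention` and `pair_bias` are hypothetical names. Unlike the real kernel, which applies `score_mod` elementwise inside a fused kernel and never materializes the full score matrix, this reference calls `score_mod` once on the whole matrix using broadcastable index tensors. That is why the bias lookup here uses tuple indexing.

```python
import torch
import torch.nn.functional as F


def reference_flex_attention(q, k, v, score_mod=None):
    """Hypothetical naive reference: softmax(score_mod(q @ k^T / sqrt(D))) @ v."""
    B, H, N, D = q.shape
    M = k.shape[2]
    # Full (B, H, N, M) score matrix, using the 1/sqrt(D) default scale
    scores = (q @ k.transpose(-2, -1)) / D ** 0.5
    if score_mod is not None:
        # Broadcastable index tensors stand in for the per-score (b, h, q_idx, kv_idx)
        b = torch.arange(B).view(B, 1, 1, 1)
        h = torch.arange(H).view(1, H, 1, 1)
        q_idx = torch.arange(N).view(1, 1, N, 1)
        kv_idx = torch.arange(M).view(1, 1, 1, M)
        scores = score_mod(scores, b, h, q_idx, kv_idx)
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
B, H, N, D = 2, 4, 8, 16
q, k, v = (torch.randn(B, H, N, D) for _ in range(3))
bias = torch.randn(B, H, N, N)


def pair_bias(score, b, h, q_idx, kv_idx):
    # Add the pair-wise bias for this (batch, head, query, key) position
    return score + bias[b, h, q_idx, kv_idx]


out = reference_flex_attention(q, k, v, score_mod=pair_bias)

# Compare with adding the bias tensor directly to the scaled logits
direct = torch.softmax(q @ k.transpose(-2, -1) / D ** 0.5 + bias, dim=-1) @ v
print(torch.allclose(out, direct, atol=1e-5))

# With no score_mod the reference reduces to plain scaled dot-product attention
print(torch.allclose(reference_flex_attention(q, k, v),
                     F.scaled_dot_product_attention(q, k, v), atol=1e-5))
```

The same reasoning explains why the fused FlexAttention output can match an einsum-plus-bias implementation. Both compute softmax over the scaled scores plus the bias. They differ only in whether the score matrix is ever written to memory.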

```python
import timeit
import torch
from torch.nn.attention.flex_attention import flex_attention

# Problem size: batch, heads, sequence length, head dim
B, H, N, D = 16, 8, 32, 128
device = 'cuda'
query = torch.randn(B, H, N, D, device=device)
key = torch.randn(B, H, N, D, device=device)
value = torch.randn(B, H, N, D, device=device)
bias = torch.randn(B, H, N, N, device=device)  # one bias value per (b, h, q, kv)

def score_mod(score, b, h, q_idx, kv_idx):
    # Add the pair bias to each scaled attention score
    return score + bias[b][h][q_idx][kv_idx]

# Compile flex_attention and warm up (the first calls trigger compilation)
flex_attention = torch.compile(flex_attention)
for _ in range(10):
    flex_attention(query, key, value, score_mod=score_mod)

def custom_attention_with_bias(q, k, v, bias):
    # Eager version: materializes the full (B, H, N, N) score matrix
    attn = torch.einsum("bhid,bhjd->bhij", q, k)
    attn = attn / (D ** 0.5) + bias.float()
    attn = torch.softmax(attn, dim=-1)
    o = torch.einsum("bhij,bhjd->bhid", attn, v)
    return o

# Note: CUDA kernels run asynchronously and these loops do not call
# torch.cuda.synchronize(), so the timings can partly reflect launch overhead.

# Time the custom attention implementation
custom_time = timeit.timeit(
    lambda: custom_attention_with_bias(query, key, value, bias),
    number=100
) / 100

# Time the flex attention implementation
flex_time = timeit.timeit(
    lambda: flex_attention(query, key, value, score_mod=score_mod),
    number=100
) / 100

print(f"Flex Attention time: {flex_time*1000:.2f} ms")
print(f"Custom Attention time: {custom_time*1000:.2f} ms")

# Check that both implementations agree numerically
o1 = flex_attention(query, key, value, score_mod=score_mod)
o2 = custom_attention_with_bias(query, key, value, bias)

print(torch.allclose(o1, o2, atol=1e-5))
```

Outputs:

```text
Flex Attention time: 0.13 ms
Custom Attention time: 1.20 ms
True
```
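On the timing itself: CUDA kernels launch asynchronously, and `timeit` does not wait for them. The loops in the snippet can therefore partly measure launch and queueing rather than kernel execution, and the eager path gets no warmup. The sketch below is a slightly more careful helper. Its name, `bench_ms`, is hypothetical. It warms up first and calls `torch.cuda.synchronize()` before reading the clock. `torch.utils.benchmark.Timer` is a ready-made alternative that also handles warmup and CUDA synchronization. Numbers measured this way may differ from the ones reported above.

```python
import time
import torch


def bench_ms(fn, *args, warmup=10, iters=100, **kwargs):
    """Average wall-clock time of fn(*args, **kwargs) in milliseconds.

    Synchronizes the GPU (when one is available) so queued CUDA kernels
    are included in the measurement, not only their launch cost.
    """
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):  # warm up lazy init, library handles, compiled graphs
        fn(*args, **kwargs)
    sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args, **kwargs)
    sync()  # wait for all queued kernels before stopping the clock
    return (time.perf_counter() - start) * 1000 / iters
```

With the tensors from the snippet, one would call `bench_ms(flex_attention, query, key, value, score_mod=score_mod)` and `bench_ms(custom_attention_with_bias, query, key, value, bias)` so that both paths are measured the same way.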

shenoynikhil, May 31 '25 23:05

So glad to see FlexAttention helps!

Regarding recompilation: when only the values of the bias tensor change, no recompilation is needed. If its shape changes, a recompilation is needed.

BoyuanFeng, Jun 02 '25 18:06
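One way to observe this behavior is to count how often `torch.compile` invokes its backend, since each invocation is a fresh graph capture. The sketch below uses a counting backend around a plain "scores + bias" function rather than `flex_attention` itself. It therefore illustrates `torch.compile`'s general guard behavior: guards check shape, dtype, and device, not tensor values. The demo runs on CPU. The names `counting_backend`, `biased_scores`, and `compile_count` are hypothetical. It passes `dynamic=False` so that every new shape forces a recompile. With the default `dynamic=None`, `torch.compile` may instead recompile once with dynamic shapes after the first shape change.

```python
import torch

compile_count = 0


def counting_backend(gm, example_inputs):
    # Invoked once per graph capture; returning gm.forward runs the graph eagerly
    global compile_count
    compile_count += 1
    return gm.forward


def biased_scores(scores, bias):
    return scores + bias


compiled = torch.compile(biased_scores, backend=counting_backend, dynamic=False)

scores = torch.randn(2, 4, 8, 8)
compiled(scores, torch.randn(2, 4, 8, 8))
compiled(scores, torch.randn(2, 4, 8, 8))  # new bias values, same shape: guards pass
print("after value change:", compile_count)

scores16 = torch.randn(2, 4, 16, 16)
compiled(scores16, torch.randn(2, 4, 16, 16))  # new shape: guards fail, recompile
print("after shape change:", compile_count)
```

Under these settings the count reads 1 after the value change and 2 after the shape change, matching the maintainer's description.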