FlagAttention
Hypno/add bias
Adds bias to attention.
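For reference, the intended semantics is ordinary scaled-dot-product attention with an additive bias on the pre-softmax scores. A minimal pure-PyTorch sketch of that (not the Triton kernel itself):
import torch as th

def attention_with_bias_ref(q, k, v, bias):
    # q, k, v: (B, H, T, D); bias: (B, H, T, T), added to the scores before softmax
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    return (scores + bias).softmax(dim=-1) @ v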
Many tests fail for me (which is why this is a draft PR), especially the BTHD-layout and longer-sequence ones (my GPU has 12 GB), but the manual PyTorch checks below seem to match.
It can be tested with:
# Check same output as torch
import torch as th
from flag_attn import flash_attention
# B, H, T, D = 2, 16, 8192, 128
B, H, T, D = 1, 1, 2048, 16
th.manual_seed(17)
th.cuda.manual_seed(17)
q = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
k = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
v = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
bias = th.randn((B, H, T, T), dtype=th.float16, device="cuda:0").requires_grad_()
go = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0")
# No-bias sanity check: flash kernel vs. manual reference vs. torch SDPA
o_flash_nobias = flash_attention(q, k, v, causal=False)
o_ref_nobias = (q @ k.transpose(-1, -2) / q.shape[-1]**0.5).softmax(dim=-1) @ v
o_th_nobias = th.nn.functional.scaled_dot_product_attention(q, k, v)

# With bias added to the attention scores
o = flash_attention(q, k, v, bias, causal=False)
o_ref = (q @ k.transpose(-1, -2) / q.shape[-1]**0.5 + bias).softmax(dim=-1) @ v
o_th = th.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert (o - o_th).abs().amax() < 1e-3
# Backward: compare gradients w.r.t. the bias
gq_th, gk_th, gv_th, gbias_th = th.autograd.grad(o_th, (q, k, v, bias), go)
gq, gk, gv, gbias = th.autograd.grad(o, (q, k, v, bias), go)
assert (gbias - gbias_th).abs().amax() < 1e-3
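If those pass, it may also be worth comparing the remaining gradients. A sketch reusing the tensors from the script above (the 5e-3 tolerance is a rough guess for fp16, not a value from the repo's tests):
# Extra checks on the q/k/v gradients against the torch SDPA reference
for name, a, b in [("gq", gq, gq_th), ("gk", gk, gk_th), ("gv", gv, gv_th)]:
    max_err = (a - b).abs().amax().item()
    print(name, "max abs err:", max_err)
    assert max_err < 5e-3, name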