
Hypno/add bias

Open · hypnopump opened this pull request on Dec 14, 2023 · 6 comments

Adds bias to attention.
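Concretely, the bias tensor is added to the pre-softmax attention scores, so the output becomes softmax(q @ k^T / sqrt(d) + bias) @ v, matching the reference computation in the test snippet below.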


Many tests fail for me (which is why this is a draft PR), especially the BTHD-layout and longer-sequence ones (my GPU has 12 GB), but manual PyTorch checks seem to match.
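For context, BTHD here refers (as I read the layout names) to tensors shaped (batch, seqlen, heads, headdim), while the snippet below uses BHTD, i.e. (batch, heads, seqlen, headdim). In plain PyTorch the two are one transpose apart:

# hypothetical layout conversion, assuming BTHD = (batch, seqlen, heads, headdim)
x_bthd = x_bhtd.transpose(1, 2).contiguous()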

It can be tested with:

# Check same output as torch

import torch as th
from flag_attn import flash_attention

# B, H, T, D = 2, 16, 8192, 128  # larger config; may not fit on a 12 GB GPU
B, H, T, D = 1, 1, 2048, 16


th.manual_seed(17)
th.cuda.manual_seed(17)
q = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
k = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
v = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0").requires_grad_()
bias = th.randn((B, H, T, T), dtype=th.float16, device="cuda:0").requires_grad_()
go = th.randn((B, H, T, D), dtype=th.float16, device="cuda:0")
# No-bias baselines: FlagAttention kernel, naive PyTorch, and SDPA
o_fa_nobias = flash_attention(q, k, v, causal=False)
o_nobias = (q @ k.transpose(-1, -2) / q.shape[-1]**0.5).softmax(dim=-1) @ v
o_th_nobias = th.nn.functional.scaled_dot_product_attention(q, k, v)

# Biased attention: bias is added to the pre-softmax scores
o = flash_attention(q, k, v, bias, causal=False)
o_ = (q @ k.transpose(-1, -2) / q.shape[-1]**0.5 + bias).softmax(dim=-1) @ v
o_th = th.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)
assert (o - o_th).abs().amax() < 1e-3

# Gradients of the PyTorch reference w.r.t. q, k, v and bias
gq_th, gk_th, gv_th, gbias_th = th.autograd.grad(
    o_th, (q, k, v, bias), go
)

# Gradients of the FlagAttention output; the bias gradient must match too
gq, gk, gv, gbias = th.autograd.grad(
    o, (q, k, v, bias), go
)
assert (gbias - gbias_th).abs().amax() < 1e-3
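
For completeness, the no-bias baselines can be sanity-checked the same way, e.g.:

# optional extra check (not part of the original script): no-bias path
assert (o_fa_nobias - o_th_nobias).abs().amax() < 1e-3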

hypnopump · Dec 14 '23