flash-attention
There is a strange computation error between standard attention and FlashAttention-2
import time
import torch
from torch.nn import functional as F
from flash_attn import flash_attn_func
from einops import rearrange
import math
def standard_attention(query_layer, key_layer, value_layer, attention_mask, scaling_attention_score=True):
    # Reference attention in the input dtype: scale, mask, softmax, weighted sum.
    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores = torch.mul(attention_scores, attention_mask) + \
        torch.finfo(query_layer.dtype).min * (1.0 - attention_mask)
    attention_probs = F.softmax(attention_scores, dim=-1)
    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer

def flash_attention(q, k, v, causal=False):
    return flash_attn_func(q, k, v, causal=causal)

def test(func_name, q, k, v, *args, **kwargs):
    # standard_attention expects (batch, heads, seq, head_dim);
    # flash_attn_func expects (batch, seq, heads, head_dim), so transpose around the call.
    if func_name in ["standard_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")
    o = globals()[func_name](q, k, v, *args, **kwargs)
    if func_name in ["standard_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")
    return o

if __name__ == "__main__":
    batch_size = 1
    seq_len = 4096
    num_head = 32
    hidden_units = 128
    # Causal (lower-triangular) mask for the reference implementation.
    f_attn_mask = torch.ones(batch_size, num_head, seq_len, seq_len, dtype=torch.float16, device="cuda")
    f_attn_mask.tril_()
    query = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda")
    key = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda")
    value = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda")
    o = test("standard_attention", query, key, value, attention_mask=f_attn_mask, scaling_attention_score=True)
    fa_o = test("flash_attention", query, key, value, causal=True)
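    # Not in the original snippet: presumably the reported 0.0005 max error was
    # measured with a direct elementwise comparison like this.
    print((o - fa_o).abs().max())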
I have observed a peculiar computation error between standard attention and FlashAttention-2. When I run the code above, the maximum error between the two is 0.0005, but once I multiply the query, key, and value by 100, the maximum error becomes 64.8125. Is this the expected numerical error?
The right thing to compare to is standard attention in fp32. In this case FlashAttention is actually more accurate than the standard implementation in fp16:
torch.manual_seed(0)
batch_size = 1
seq_len = 4096
num_head = 32
hidden_units = 128
f_attn_mask = torch.ones(batch_size, num_head, seq_len, seq_len, dtype=torch.float16, device="cuda")
f_attn_mask.tril_()
query = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda") * 100
key = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda") * 100
value = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda") * 100
o = test("standard_attention", query, key, value, attention_mask=f_attn_mask, scaling_attention_score=True)
fa_o = test("flash_attention", query, key, value, causal=True)
o_ref = test("standard_attention", query.float(), key.float(), value.float(), attention_mask=f_attn_mask, scaling_attention_score=True)
print((o - o_ref).abs().max()) # 65.7
print((fa_o - o_ref).abs().max()) # 32.6
The numerical error here is probably unavoidable because some of the computation is done in fp16, which has limited numerical precision.
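For intuition, a back-of-the-envelope sketch (my addition, not from the thread): fp16 carries roughly three decimal digits of relative precision, and its absolute resolution grows with the magnitude of the values. With the inputs multiplied by 100, the pre-softmax scores end up on the order of 10^4, where neighbouring fp16 values are 16 apart, so output differences of tens of units are plausible purely from rounding.

import torch

# fp16 has a 10-bit mantissa, i.e. roughly 3 decimal digits of relative precision.
print(torch.finfo(torch.float16).eps)   # 0.0009765625

# Its absolute resolution grows with magnitude: around 2048 the spacing between
# neighbouring fp16 values is already 2.0, so adding 1.0 is lost to rounding.
x = torch.tensor(2048.0, dtype=torch.float16)
print(x + 1.0 == x)                     # tensor(True)

# Around 3e4 the spacing is 16, and fp16 overflows at 65504 altogether.
print(torch.finfo(torch.float16).max)   # 65504.0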
Thank you for your reply. I have another question: I am working on enabling prefix attention mask support for flash-attention. I have completed the forward pass, but I have some doubts about the backward pass. What does this code mean, and in what scenarios does this logic execute? https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h#L647
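(For context: a prefix attention mask here presumably means a prefix-LM style mask that is bidirectional over the first prefix_len tokens and causal afterwards. Below is a minimal, illustrative sketch of such a mask for the fp32 reference path; build_prefix_mask and prefix_len are hypothetical names, not part of the flash-attention API.)

def build_prefix_mask(seq_len, prefix_len, device="cuda", dtype=torch.float16):
    # Causal part: query i may attend to keys j <= i.
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device, dtype=dtype))
    # Prefix part: every query may additionally attend to the first prefix_len keys,
    # which makes attention bidirectional inside the prefix.
    mask[:, :prefix_len] = 1.0
    return mask  # broadcastable to (batch, num_head, seq_len, seq_len)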
leizhao1234: Hello, I had the same problem with bf16. Do you have a solution?
Are you referring to the bf16 computation error issue?
Yes, once I multiplied the f_attn_mask by 100, the same computation error appeared.
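(As a hedged aside, not from the thread: bf16 keeps fp32's range but has an even shorter mantissa than fp16, so a comparable or larger rounding error is to be expected once the values are scaled up.)

import torch

# bf16 has only a 7-bit mantissa vs fp16's 10 bits, so its relative precision is coarser.
print(torch.finfo(torch.bfloat16).eps)   # 0.0078125
print(torch.finfo(torch.float16).eps)    # 0.0009765625
# bf16's advantage is range, not precision: it does not overflow until ~3.4e38.
print(torch.finfo(torch.bfloat16).max)   # 3.3895313892515355e+38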
This paper also talks about instabilities of flash attention: https://arxiv.org/pdf/2405.02803v1