
There was a strange computation error between standard attention and flash-attention2

Open leizhao1234 opened this issue 2 years ago • 6 comments

import time
import torch
from torch.nn import functional as F
from flash_attn import flash_attn_func
from einops import rearrange
import math

def standard_attention(query_layer, key_layer, value_layer, attention_mask, scaling_attention_score=True):
    if scaling_attention_score:
        query_layer = query_layer / math.sqrt(query_layer.shape[-1])
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    attention_scores = torch.mul(attention_scores, attention_mask) + \
                       torch.finfo(query_layer.dtype).min * (1.0 - attention_mask)

    attention_probs = F.softmax(attention_scores, dim=-1)

    context_layer = torch.matmul(attention_probs, value_layer)
    return context_layer

def flash_attention(q, k, v, causal=False):
    o = flash_attn_func(q, k, v, causal=causal)
    return o

def test(func_name, q, k, v, *args, **kwargs):
    if func_name in ["standard_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")

    o = globals()[func_name](q, k, v, *args, **kwargs)
    if func_name in ["standard_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")
    return o

if __name__ == "__main__":
    batch_size = 1
    seq_len = 4096
    num_head = 32
    hidden_units = 128
    f_attn_mask = torch.ones(batch_size, num_head, seq_len, seq_len, dtype=torch.float16, device="cuda")
    f_attn_mask.tril_()
    query = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda")
    key = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda")
    value = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda")
    o = test("standard_attention", query, key, value, attention_mask=f_attn_mask, scaling_attention_score=True)
    fa_o = test("flash_attention", query, key, value, causal=True)

I have observed a peculiar computation error between standard attention and FlashAttention-2. When I run the code above, the maximum error between the two is 0.0005, but once I multiply the query, key, and value by 100, the maximum error becomes 64.8125. Is this the expected numerical error?
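
The script above does not include the comparison itself; the quoted maxima came from something along these lines (a sketch, assuming the script has just been run):

# Largest elementwise difference between the fp16 standard-attention output
# and the FlashAttention-2 output.
print((o - fa_o).abs().max().item())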

leizhao1234 avatar Aug 08 '23 06:08 leizhao1234

The right thing to compare to is standard attention in fp32. In this case FlashAttention is actually more accurate than the standard implementation in fp16:

torch.manual_seed(0)
batch_size = 1
seq_len = 4096
num_head = 32
hidden_units = 128
f_attn_mask = torch.ones(batch_size, num_head, seq_len, seq_len, dtype=torch.float16, device="cuda")
f_attn_mask.tril_()
query = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda") * 100
key = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda") * 100
value = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=torch.float16, device="cuda") * 100
o = test("standard_attention", query, key, value, attention_mask=f_attn_mask, scaling_attention_score=True)
fa_o = test("flash_attention", query, key, value, causal=True)
o_ref = test("standard_attention", query.float(), key.float(), value.float(), attention_mask=f_attn_mask, scaling_attention_score=True)
print((o - o_ref).abs().max())  # 65.7
print((fa_o - o_ref).abs().max())  # 32.6

The numerical error here is probably unavoidable because some of the computation is done in fp16, which has limited numerical precision.
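
For a sense of that limit, the machine epsilon of the two half-precision formats bounds the relative accuracy of any single rounded operation:

import torch

# fp16 keeps ~10 mantissa bits, bf16 only ~7, so even one rounded
# operation can be off by roughly this relative amount.
print(torch.finfo(torch.float16).eps)   # 0.0009765625
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125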

tridao avatar Aug 09 '23 06:08 tridao

Thank you for your reply. I have another question. I am working on adding prefix attention mask support to flash-attention. I have already completed the forward pass, but I have some doubts about the backward pass. What does this code do, and in what scenarios is it executed? https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/flash_bwd_kernel.h#L647
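
For context, by a prefix attention mask I mean the usual prefix-LM pattern: the first prefix_len tokens attend bidirectionally among themselves, and the remaining tokens attend causally. A dense reference mask (illustration only, unrelated to the kernel code linked above; prefix_lm_mask is just a name for this sketch) looks roughly like this:

import torch

def prefix_lm_mask(seq_len, prefix_len, device="cuda"):
    # Start from a causal mask, then allow every position to see the
    # prefix tokens, which makes the prefix block fully bidirectional.
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    mask[:, :prefix_len] = True
    return mask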

leizhao1234 avatar Aug 09 '23 13:08 leizhao1234

@leizhao1234 Hello, I had the same problem with bf16. Do you have a solution?

shiqingzhangCSU avatar Aug 11 '23 06:08 shiqingzhangCSU

Are you referring to the bf16 computation error issue?

leizhao1234 avatar Aug 11 '23 07:08 leizhao1234

"Are you referring to the bf16 computation error issue?" Yes, once I multiplied the f_attn_mask by 100, I ran into the same computation error.
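
If it helps, the fp32-reference comparison from the earlier reply can be adapted to bf16 by just switching the dtype (a sketch, assuming the test helper and shape constants from the first post are in scope):

torch.manual_seed(0)
dtype = torch.bfloat16
f_attn_mask = torch.ones(batch_size, num_head, seq_len, seq_len, dtype=dtype, device="cuda").tril_()
query = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=dtype, device="cuda") * 100
key = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=dtype, device="cuda") * 100
value = torch.rand((batch_size, seq_len, num_head, hidden_units), dtype=dtype, device="cuda") * 100
o = test("standard_attention", query, key, value, attention_mask=f_attn_mask, scaling_attention_score=True)
fa_o = test("flash_attention", query, key, value, causal=True)
o_ref = test("standard_attention", query.float(), key.float(), value.float(), attention_mask=f_attn_mask.float(), scaling_attention_score=True)
# Compare both bf16 results against the fp32 reference; the question is whether
# FlashAttention's error is comparable to the standard bf16 implementation's.
print((o - o_ref).abs().max())
print((fa_o - o_ref).abs().max())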

shiqingzhangCSU avatar Aug 11 '23 07:08 shiqingzhangCSU

This paper also talks about instabilities of flash attention: https://arxiv.org/pdf/2405.02803v1

jinhuaca avatar Jul 18 '24 22:07 jinhuaca