
Why do the outputs of flash attention and multi-head attention differ significantly under the same parameters?


This is the flash attention code that I have encapsulated:

```python
import math

import torch
import torch.nn as nn
from torch.nn.init import (
    xavier_uniform_,
    constant_,
    xavier_normal_
)
from torch.nn.functional import linear

from einops import rearrange
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.registry import ATTENTION

from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from flash_attn import flash_attn_func
from flash_attn.modules.mha import FlashSelfAttention
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


def _in_projection_packed(q, k, v, w, b=None):
    w_q, w_k, w_v = w.chunk(3)
    # print(w.device, w_q.device)
    if b is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = b.chunk(3)
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


@ATTENTION.register_module()
class FlashAttention(BaseModule):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.1)
    ---------
        out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
            else (B, S, H, D).
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.fp16_enabled = True
        self.flash_self_attn = FlashSelfAttention(attention_dropout=attention_dropout)

    @auto_fp16(apply_to=('q', 'kv'), out_fp32=False)
    def forward(self, q, kv, causal=False):
        # key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q: The tensor containing the query. (B, T, H, D)
            kv: The tensor containing the key and value. (B, S, 2, H, D)
            key_padding_mask: a bool tensor of shape (B, S)
        """
        assert q.dtype in [torch.float16, torch.bfloat16] and kv.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        assert q.shape[0] == kv.shape[0] and q.shape[-2] == kv.shape[-2] and q.shape[-1] == kv.shape[-1]
        q = q.unsqueeze(2)
        assert q.shape[1] == kv.shape[1]
        print(q.shape, kv.shape)
        qkv = torch.cat([q, kv], dim=2)
        # print(qkv.shape)
        result = self.flash_self_attn(qkv, causal=causal)

        return result

        # batch_size = q.shape[0]
        # seqlen_q, seqlen_k = q.shape[1], kv.shape[1]
        # # if key_padding_mask is None:
        # q, kv = rearrange(q, 'b s ... -> (b s) ...'), rearrange(kv, 'b s ... -> (b s) ...')
        # max_sq, max_sk = seqlen_q, seqlen_k
        # cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
        #                             device=q.device)
        # cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
        #                             device=kv.device)
        # print(q.shape, kv.shape)
        # output = flash_attn_varlen_kvpacked_func(
        #     q, kv, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
        #     self.dropout_p if self.training else 0.0,
        #     softmax_scale=self.softmax_scale, causal=causal
        # )
        # output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        # else:
        #     nheads = kv.shape[-2]
        #     q = rearrange(q, 'b s ... -> (b s) ...')
        #     max_sq = seqlen_q
        #     cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
        #                                 device=q.device)
        #     x = rearrange(kv, 'b s two h d -> b s (two h d)')
        #     x_unpad, indices, cu_seqlens_k, max_sk = unpad_input(x, key_padding_mask)
        #     x_unpad = rearrange(x_unpad, 'nnz (two h d) -> nnz two h d', two=2, h=nheads)
        #     output_unpad = flash_attn_varlen_kvpacked_func(
        #         q, x_unpad, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
        #         self.dropout_p if self.training else 0.0,
        #         softmax_scale=self.softmax_scale, causal=causal
        #     )
        #     output = rearrange(output_unpad, '(b s) ... -> b s ...', b=batch_size)

        return output, None

@ATTENTION.register_module
class FlashMHA(BaseModule):

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.embed_dim = embed_dim
        self.causal = causal
        self.bias = bias

        if not self.training:
            attention_dropout = 0.

        # self.linear_q = nn.Linear(embed_dim, embed_dim, bias=bias)
        # self.linear_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        # self.linear_v = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"

        self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim))).to(factory_kwargs['dtype']).to(factory_kwargs['device'])
        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)).to(factory_kwargs['dtype']).to(factory_kwargs['device'])
        else:
            self.register_parameter('in_proj_bias', None)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, dtype=dtype)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        # xavier_uniform_(self.linear_q.weight)
        # xavier_uniform_(self.linear_k.weight)
        # xavier_uniform_(self.linear_v.weight)
        # if self.bias is not None:
        #     constant_(self.linear_q.bias, 0.)
        #     constant_(self.linear_k.bias, 0.)
        #     constant_(self.linear_v.bias, 0.)
        #     constant_(self.out_proj.bias, 0.)

    # def _in_projection_packed(self, q, k, v):
    #     return self.linear_q(q), self.linear_k(k), self.linear_v(v)

    def forward(self, q, k, v):
        """x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
        key_padding_mask: bool tensor of shape (batch, seqlen)
        """
        # q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
        # print(self.in_proj_weight.device, self.in_proj_bias.device)
        print(q.shape, k.shape, v.shape)
        q, k, v = _in_projection_packed(q, k, v, self.in_proj_weight, self.in_proj_bias)
        # print(q.dtype, k.dtype, v.dtype)
        # q, k, v = self._in_projection_packed(q, k, v)
        q = rearrange(q, 'b s (h d) -> b s h d', h=self.num_heads)
        k = rearrange(k, 'b s (h d) -> b s h d', h=self.num_heads).unsqueeze(2)
        v = rearrange(v, 'b s (h d) -> b s h d', h=self.num_heads).unsqueeze(2)
        kv = torch.cat([k, v], dim=2)
        # print(q.dtype, kv.dtype)

        context = self.inner_attn(q, kv, causal=self.causal)

        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
```
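
For context, a minimal standalone sketch (illustrative sizes only, not part of the issue's code) of the packed layout this wrapper builds: `FlashSelfAttention` consumes a `qkv` tensor of shape `(B, S, 3, H, D)`, which the `FlashAttention.forward` above assembles from `q` and `kv`:

```python
import torch

# Illustrative sizes; head_dim must be <= 128 for the flash kernels.
B, S, H, D = 2, 16, 8, 64
q = torch.randn(B, S, H, D, device='cuda', dtype=torch.float16)      # (B, S, H, D)
kv = torch.randn(B, S, 2, H, D, device='cuda', dtype=torch.float16)  # (B, S, 2, H, D)

# Same packing as FlashAttention.forward: insert q as the first of the three slots.
qkv = torch.cat([q.unsqueeze(2), kv], dim=2)
assert qkv.shape == (B, S, 3, H, D)  # layout expected by FlashSelfAttention
```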

This is the test code:

```python
batch_size = 128
nheads = 8
seqlen = 1024
n = 512
# d = n // nheads
dropout_p = 0
# causal = False
dtype = torch.float16
device = 'cuda'

q = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
# q = torch.randn(batch_size, n, nheads, d, device='cuda', dtype=dtype)
# k = torch.randn(batch_size, n, nheads, d, device='cuda', dtype=dtype)
# v = torch.randn(batch_size, n, nheads, d, device='cuda', dtype=dtype)

# key_padding_mask = k.new_ones(batch_size, seqlen)
model = FlashMHA(embed_dim=n, num_heads=nheads, device=device, dtype=dtype).to(device)
# model = MHA(n, nheads, device=device, dtype=dtype).to(device)
model.eval()
for key, params in model.state_dict().items():
    print(key, params.size())

# k = k.unsqueeze(2)
# v = v.unsqueeze(2)
# kv = torch.cat([k, v], dim=2)

result = model(q, k, v)
# print(result)
print(result.shape)



model2 = nn.MultiheadAttention(n, nheads, device=device, dtype=dtype).to(device)
print(model2)
model2.eval()
for key, params in model2.state_dict().items():
    print(key, params.size())

model.in_proj_weight.data.fill_(1e-3)
model.in_proj_weight.data[:n, :].fill_(1e-4)
model.in_proj_weight.data[-n:, :].fill_(1e-2)
model.in_proj_bias.data.fill_(0)
# model.linear_q.weight.data.fill_(1e-3)
# model.linear_q.bias.data.fill_(0)
# model.linear_k.weight.data.fill_(1e-4)
# model.linear_k.bias.data.fill_(0)
# model.linear_v.weight.data.fill_(1e-2)
# model.linear_v.bias.data.fill_(0)
# model.Wqkv.weight.data.fill_(1e-3)
# model.Wqkv.bias.data.fill_(0)
model.out_proj.weight.data.fill_(1e-2)
model.out_proj.bias.data.fill_(0)

model2.in_proj_weight.data.fill_(1e-3)
model2.in_proj_weight.data[:n, :].fill_(1e-4)
model2.in_proj_weight.data[-n:, :].fill_(1e-2)
model2.in_proj_bias.data.fill_(0)
model2.out_proj.weight.data.fill_(1e-2)
model2.out_proj.bias.data.fill_(0)
print(q.size(), k.size(), v.size())
out1 = model(q, k, v)
out2 = model2(q.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2))[0].permute(1, 0, 2)
print("out shape:", out1.size(), out2.size())

# print(out1.dtype, out2.dtype)
print("different: ", (out1 - out2).flatten().abs().max())

```

Comparison results under the same input (screenshot).

— Dominic23331, Jul 10 '24

What's the difference?

The right comparison is (FlashAttention in fp16 − reference implementation in fp32) vs (reference implementation in fp16 − reference implementation in fp32).

— tridao, Jul 10 '24
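
For illustration, a minimal sketch of that comparison, assuming `model`, `model2`, `q`, `k`, and `v` are defined as in the test code above; the fp32 copy `model2_fp32` is introduced here only for the check:

```python
import copy

import torch

# Hypothetical fp32 copy of the PyTorch reference, sharing the same weights.
model2_fp32 = copy.deepcopy(model2).float()

with torch.no_grad():
    out_flash_fp16 = model(q, k, v)
    out_ref_fp16 = model2(q.permute(1, 0, 2), k.permute(1, 0, 2),
                          v.permute(1, 0, 2))[0].permute(1, 0, 2)
    out_ref_fp32 = model2_fp32(q.float().permute(1, 0, 2), k.float().permute(1, 0, 2),
                               v.float().permute(1, 0, 2))[0].permute(1, 0, 2)

    # Compare each fp16 path against the fp32 reference, not against each other.
    err_flash = (out_flash_fp16.float() - out_ref_fp32).abs().max().item()
    err_ref = (out_ref_fp16.float() - out_ref_fp32).abs().max().item()
    print("flash fp16 vs ref fp32:", err_flash)
    print("ref   fp16 vs ref fp32:", err_ref)
```

If the two error magnitudes come out comparable, the mismatch between the two fp16 implementations is ordinary fp16 rounding rather than a correctness problem in either one.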

I ran into the same problem. Did you resolve it?

— MyraYu2022, Dec 24 '24