flash-attention
Why do the outputs of FlashAttention and multi-head attention differ significantly under the same parameters?
This is the FlashAttention code that I have wrapped:

```python
import math

import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_, constant_, xavier_normal_
from torch.nn.functional import linear

from einops import rearrange
from mmcv.runner import auto_fp16
from mmcv.runner.base_module import BaseModule
from mmcv.cnn.bricks.registry import ATTENTION

from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from flash_attn import flash_attn_func
from flash_attn.modules.mha import FlashSelfAttention
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis


def _in_projection_packed(q, k, v, w, b=None):
    # Split the packed (3 * embed_dim, embed_dim) projection into q, k, v projections.
    w_q, w_k, w_v = w.chunk(3)
    # print(w.device, w_q.device)
    if b is None:
        b_q = b_k = b_v = None
    else:
        b_q, b_k, b_v = b.chunk(3)
    return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)


@ATTENTION.register_module()
class FlashAttention(BaseModule):
    """Implement the scaled dot product attention with softmax.

    Arguments
    ---------
    softmax_scale: The temperature to use for the softmax attention.
        (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
    attention_dropout: The dropout rate to apply to the attention (default: 0.1)
    ---------
    out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
        else (B, S, H, D).
    """

    def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
        super().__init__()
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.fp16_enabled = True
        self.flash_self_attn = FlashSelfAttention(attention_dropout=attention_dropout)

    @auto_fp16(apply_to=('q', 'kv'), out_fp32=False)
    def forward(self, q, kv, causal=False):
        # key_padding_mask=None):
        """Implements the multihead softmax attention.

        Arguments
        ---------
        q: The tensor containing the query. (B, T, H, D)
        kv: The tensor containing the key and value. (B, S, 2, H, D)
        key_padding_mask: a bool tensor of shape (B, S)
        """
        assert q.dtype in [torch.float16, torch.bfloat16] and kv.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        assert q.shape[0] == kv.shape[0] and q.shape[-2] == kv.shape[-2] and q.shape[-1] == kv.shape[-1]
        # Pack q and kv into a single (B, S, 3, H, D) tensor for FlashSelfAttention.
        q = q.unsqueeze(2)
        assert q.shape[1] == kv.shape[1]
        print(q.shape, kv.shape)
        qkv = torch.cat([q, kv], dim=2)
        # print(qkv.shape)
        result = self.flash_self_attn(qkv, causal=causal)
        return result

        # Alternative path through the varlen kvpacked kernel, kept commented out for reference:
        # batch_size = q.shape[0]
        # seqlen_q, seqlen_k = q.shape[1], kv.shape[1]
        # # if key_padding_mask is None:
        # q, kv = rearrange(q, 'b s ... -> (b s) ...'), rearrange(kv, 'b s ... -> (b s) ...')
        # max_sq, max_sk = seqlen_q, seqlen_k
        # cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
        #                             device=q.device)
        # cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
        #                             device=kv.device)
        # print(q.shape, kv.shape)
        # output = flash_attn_varlen_kvpacked_func(
        #     q, kv, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
        #     self.dropout_p if self.training else 0.0,
        #     softmax_scale=self.softmax_scale, causal=causal
        # )
        # output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        # else:
        #     nheads = kv.shape[-2]
        #     q = rearrange(q, 'b s ... -> (b s) ...')
        #     max_sq = seqlen_q
        #     cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
        #                                 device=q.device)
        #     x = rearrange(kv, 'b s two h d -> b s (two h d)')
        #     x_unpad, indices, cu_seqlens_k, max_sk = unpad_input(x, key_padding_mask)
        #     x_unpad = rearrange(x_unpad, 'nnz (two h d) -> nnz two h d', two=2, h=nheads)
        #     output_unpad = flash_attn_varlen_kvpacked_func(
        #         q, x_unpad, cu_seqlens_q, cu_seqlens_k, max_sq, max_sk,
        #         self.dropout_p if self.training else 0.0,
        #         softmax_scale=self.softmax_scale, causal=causal
        #     )
        #     output = rearrange(output_unpad, '(b s) ... -> b s ...', b=batch_size)
        # return output, None


@ATTENTION.register_module()
class FlashMHA(BaseModule):

    def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                 causal=False, device=None, dtype=None, **kwargs) -> None:
        assert batch_first
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.causal = causal
        self.bias = bias
        if not self.training:
            attention_dropout = 0.
        # self.linear_q = nn.Linear(embed_dim, embed_dim, bias=bias)
        # self.linear_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        # self.linear_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads
        assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
        # Create the packed qkv projection directly with the requested device/dtype;
        # calling .to() on an nn.Parameter would return a plain tensor and break registration.
        self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        xavier_uniform_(self.in_proj_weight)
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        # xavier_uniform_(self.linear_q.weight)
        # xavier_uniform_(self.linear_k.weight)
        # xavier_uniform_(self.linear_v.weight)
        # if self.bias is not None:
        #     constant_(self.linear_q.bias, 0.)
        #     constant_(self.linear_k.bias, 0.)
        #     constant_(self.linear_v.bias, 0.)
        #     constant_(self.out_proj.bias, 0.)

    # def _in_projection_packed(self, q, k, v):
    #     return self.linear_q(q), self.linear_k(k), self.linear_v(v)

    def forward(self, q, k, v):
        """q, k, v: (batch, seqlen, hidden_dim), where hidden_dim = num_heads * head_dim."""
        # q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
        # print(self.in_proj_weight.device, self.in_proj_bias.device)
        print(q.shape, k.shape, v.shape)
        q, k, v = _in_projection_packed(q, k, v, self.in_proj_weight, self.in_proj_bias)
        # print(q.dtype, k.dtype, v.dtype)
        # q, k, v = self._in_projection_packed(q, k, v)
        # Split heads, then stack k and v into a (b, s, 2, h, d) kv tensor.
        q = rearrange(q, 'b s (h d) -> b s h d', h=self.num_heads)
        k = rearrange(k, 'b s (h d) -> b s h d', h=self.num_heads).unsqueeze(2)
        v = rearrange(v, 'b s (h d) -> b s h d', h=self.num_heads).unsqueeze(2)
        kv = torch.cat([k, v], dim=2)
        # print(q.dtype, kv.dtype)
        context = self.inner_attn(q, kv, causal=self.causal)
        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
```

This is the test code:

```python
batch_size = 128
nheads = 8
seqlen = 1024
n = 512
# d = n // nheads
dropout_p = 0
# causal = False
dtype = torch.float16
device = 'cuda'

q = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
# q = torch.randn(batch_size, n, nheads, d, device='cuda', dtype=dtype)
# k = torch.randn(batch_size, n, nheads, d, device='cuda', dtype=dtype)
# v = torch.randn(batch_size, n, nheads, d, device='cuda', dtype=dtype)
# key_padding_mask = k.new_ones(batch_size, seqlen)

model = FlashMHA(embed_dim=n, num_heads=nheads, device=device, dtype=dtype).to(device)
# model = MHA(n, nheads, device=device, dtype=dtype).to(device)
model.eval()
for key, params in model.state_dict().items():
    print(key, params.size())
# k = k.unsqueeze(2)
# v = v.unsqueeze(2)
# kv = torch.cat([k, v], dim=2)
result = model(q, k, v)
# print(result)
print(result.shape)

model2 = nn.MultiheadAttention(n, nheads, device=device, dtype=dtype).to(device)
print(model2)
model2.eval()
for key, params in model2.state_dict().items():
    print(key, params.size())

# Fill both models with the same deterministic weights so their outputs are comparable.
model.in_proj_weight.data.fill_(1e-3)
model.in_proj_weight.data[:n, :].fill_(1e-4)
model.in_proj_weight.data[-n:, :].fill_(1e-2)
model.in_proj_bias.data.fill_(0)
# model.linear_q.weight.data.fill_(1e-3)
# model.linear_q.bias.data.fill_(0)
# model.linear_k.weight.data.fill_(1e-4)
# model.linear_k.bias.data.fill_(0)
# model.linear_v.weight.data.fill_(1e-2)
# model.linear_v.bias.data.fill_(0)
# model.Wqkv.weight.data.fill_(1e-3)
# model.Wqkv.bias.data.fill_(0)
model.out_proj.weight.data.fill_(1e-2)
model.out_proj.bias.data.fill_(0)
model2.in_proj_weight.data.fill_(1e-3)
model2.in_proj_weight.data[:n, :].fill_(1e-4)
model2.in_proj_weight.data[-n:, :].fill_(1e-2)
model2.in_proj_bias.data.fill_(0)
model2.out_proj.weight.data.fill_(1e-2)
model2.out_proj.bias.data.fill_(0)

print(q.size(), k.size(), v.size())
out1 = model(q, k, v)
# nn.MultiheadAttention defaults to batch_first=False, so permute to (seqlen, batch, dim) and back.
out2 = model2(q.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2))[0].permute(1, 0, 2)
print("out shape:", out1.size(), out2.size())
# print(out1.dtype, out2.dtype)
print("difference: ", (out1 - out2).flatten().abs().max())
```
Comparison results under the same input:
What's the difference?
The right comparison is (FlashAttention in fp16 - reference implementation in fp32) vs. (reference implementation in fp16 - reference implementation in fp32).
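A minimal sketch of that comparison, reusing `q`, `k`, `v`, `out1`, `out2`, `model2`, `n`, `nheads`, and `device` from the test code above; the fp32 copy and the name `model2_fp32` are illustrative, not part of the original script:

```python
import torch
import torch.nn as nn

# Build an fp32 copy of the PyTorch reference and load the same weights.
model2_fp32 = nn.MultiheadAttention(n, nheads, device=device, dtype=torch.float32)
model2_fp32.load_state_dict({name: p.float() for name, p in model2.state_dict().items()})
model2_fp32.eval()

with torch.no_grad():
    # fp32 reference output (nn.MultiheadAttention expects (seqlen, batch, dim) by default).
    ref32 = model2_fp32(q.float().permute(1, 0, 2), k.float().permute(1, 0, 2),
                        v.float().permute(1, 0, 2))[0].permute(1, 0, 2)

# FlashAttention's fp16 error against the fp32 reference should be of the same
# order as the fp16 reference's own error against the fp32 reference.
print("flash fp16 vs ref fp32:", (out1.float() - ref32).abs().max().item())
print("ref fp16   vs ref fp32:", (out2.float() - ref32).abs().max().item())
```

If the two gaps are of comparable size, the observed `out1 - out2` difference is fp16 round-off rather than a bug in the wrapper.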
I ran into the same problem. Did you resolve it?