Flash Attention 2 Output not Equal to PyTorch scaled_dot_product_attention in MusicGen Inference
I swapped out the Torch attention function for Flash Attention 2 in the MusicGen project here like so:
if self.memory_efficient:
    p = self.dropout if self.training else 0
    if _efficient_attention_backend == 'torch':
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=attn_mask is not None, dropout_p=p)
        y = flash_attn_func(q, k, v, causal=attn_mask is not None, dropout_p=p)
        assert torch.allclose(x, y, atol=1e-5), "flash_attn_func and scaled_dot_product_attention are not equal"
    else:
        x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
but the resulting tensors were not equal (my assertion halted execution) during inference like so:
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
model = MusicGen.get_pretrained('melody')
model.set_generation_params(duration=8) # generate 8 seconds.
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav = model.generate(descriptions) # generates 3 samples.
for idx, one_wav in enumerate(wav):
    # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)
Additionally, setting aside the significant divergence above (which affects the model output), I also run into the following error during inference:
Number of heads in key/value must divide number of heads in query
This happens because the query's num_heads stays at 1 while the key's num_heads scales with the number of tokens in the context window. However, I assume this will be solved by the forthcoming inference optimizations mentioned here.
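For reference, a minimal sketch of how this error can arise (the shapes below are hypothetical): flash_attn_func expects (batch, seqlen, nheads, headdim), so tensors laid out for scaled_dot_product_attention as (batch, nheads, seqlen, headdim) have their sequence length read as a head count, and a growing KV cache then looks like a growing number of key/value heads.

import torch
from flash_attn import flash_attn_func

B, nh, d = 1, 16, 64
# Tensors in SDPA's (batch, nheads, seqlen, headdim) layout; shapes are hypothetical.
q = torch.randn(B, nh, 1, d, dtype=torch.float16, device="cuda")   # decode step: query length 1
k = torch.randn(B, nh, 10, d, dtype=torch.float16, device="cuda")  # KV cache: 10 cached tokens
v = torch.randn_like(k)

# flash_attn_func reads dim 2 as the head count, so q appears to have 1 head
# and k appears to have 10, triggering
# "Number of heads in key/value must divide number of heads in query".
out = flash_attn_func(q, k, v, causal=True)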
The right comparison is (FlashAttention in fp16/bf16 - standard attention in fp32) vs (standard attention in fp16/bf16 - standard attention in fp32).
What are the two differences ^^^ in your use case?
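Concretely, that comparison could be sketched roughly as follows (the shapes here are made up, and flash_attn_func is assumed to take inputs in (batch, seqlen, nheads, headdim) layout, as in the layout example further down):

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

torch.manual_seed(0)
# (batch, seqlen, nheads, headdim) layout for flash_attn_func; shapes are made up.
q = torch.randn(1, 128, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

def sdpa(q, k, v, dtype):
    # scaled_dot_product_attention expects (batch, nheads, seqlen, headdim).
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2).to(dtype),
        k.transpose(1, 2).to(dtype),
        v.transpose(1, 2).to(dtype))
    return out.transpose(1, 2)

ref_fp32   = sdpa(q, k, v, torch.float32)           # standard attention in fp32
std_fp16   = sdpa(q, k, v, torch.float16).float()   # standard attention in fp16
flash_fp16 = flash_attn_func(q, k, v).float()       # FlashAttention in fp16

# The two differences to compare:
print("FlashAttention fp16 vs fp32 reference:", (flash_fp16 - ref_fp32).abs().max().item())
print("standard fp16 vs fp32 reference:", (std_fp16 - ref_fp32).abs().max().item())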
Sorry I'm not quite sure what you're hinting at. I'm comparing FlashAttention 2 FP16 vs Torch's Memory Efficient SDP Attention FP16.
I think for scaled_dot_product_attention you need to transpose the second and the third axes of qkv. It requires a different layout.
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

query = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
key = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
value = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
is_causal = True

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        ref = F.scaled_dot_product_attention(torch.permute(query, [0, 2, 1, 3]),
                                             torch.permute(key, [0, 2, 1, 3]),
                                             torch.permute(value, [0, 2, 1, 3]),
                                             is_causal=is_causal)
        ref = torch.permute(ref, [0, 2, 1, 3])
    out = flash_attn_func(query, key, value, causal=is_causal)

print(torch.max(torch.abs(out - ref)), torch.mean(torch.abs(out - ref)))
@tridao I've checked the differences you suggested, and they are of the same order as you expected.
Still, could you explain where the difference between FlashAttention in fp16 and standard attention in fp16 comes from? Cheers
Floating-point operations are not associative. Changing the order of the operations will change the output, up to numerical precision. Example:
In [1]: import torch
In [2]: a = torch.randn(1024, dtype=torch.float16, device="cuda")
In [3]: out1 = a + 0.3 - 0.2
In [4]: out2 = a - 0.2 + 0.3
In [5]: (out1 - out2).abs().max()
Out[5]: tensor(0.0020, device='cuda:0', dtype=torch.float16)
import math
import torch
import torch.nn.functional as F

B = 1; T = 3; nh = 32; C = 64
# q = torch.arange(B * T * nh * C).reshape([B, T, nh, C]).float();
# k = torch.arange(B * T * nh * C).flip(dims=(-1,)).reshape([B, T, nh, C]).float();
# v = torch.arange(B * T * nh * C).reshape([B, T, nh, C]).float() ** 1.5
q = torch.randn([B, T, nh, C])
k = torch.randn([B, T, nh, C])
v = torch.randn([B, T, nh, C])

out = F.scaled_dot_product_attention(q.permute([0, 2, 1, 3]).cuda(),
                                     k.permute([0, 2, 1, 3]).cuda(),
                                     v.permute([0, 2, 1, 3]).cuda(),
                                     dropout_p=0.0, is_causal=False).permute([0, 2, 1, 3])

def func(q, k, v, use_causal=False):
    # (B, T, nh, C) -> q: (B, nh, T, C), k: (B, nh, C, T), v: (B, nh, T, C)
    q, k, v = q.permute([0, 2, 1, 3]), k.permute([0, 2, 3, 1]), v.permute([0, 2, 1, 3])
    # attn = (q @ k.transpose([0, 1, 3, 2])) * (1.0 + math.sqrt(k.shape[-1]))
    # note: after the permute above, k.shape[-1] is T, not the head dim C
    attn = (q @ k) * (1.0 / math.sqrt(k.shape[-1]))
    if use_causal:
        causal_mask = torch.tril(torch.ones(T, T, device=attn.device)).view(1, 1, T, T)
        attn = attn.masked_fill(causal_mask == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    out = attn @ v  # (B, nh, T, T) x (B, nh, T, c) -> (B, nh, T, c)
    out = out.permute([0, 2, 1, 3])  # (B, T, nh, c)
    return out

out2 = func(q.cuda(), k.cuda(), v.cuda(), use_causal=False)
print((out2 - out).abs().mean())  # about 0.3
I use torch.float32 in a unit test and find that the difference between normal attention and flash attention is about 0.3. I think the precision should be sufficient, but the difference is large. Could someone tell me why?
With torch.float32, F.scaled_dot_product_attention does not call FlashAttention (it is only implemented for fp16 and bf16). You can ask in the PyTorch GitHub.
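As a rough way to confirm this (using the same sdp_kernel context manager as in the earlier example), you can restrict scaled_dot_product_attention to the flash backend and feed it fp32 inputs; PyTorch should then refuse to run, since the flash kernel only supports fp16/bf16:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda")  # (batch, nheads, seqlen, headdim), fp32
k = torch.randn_like(q)
v = torch.randn_like(q)

# Allow only the FlashAttention backend of SDPA.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        # Expected: no kernel is available because the flash backend does not support fp32.
        print(e)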
Should I be worried about the difference?