
Flash Attention 2 Output not Equal to PyTorch scaled_dot_product_attention in MusicGen Inference

Open zaptrem opened this issue 2 years ago • 8 comments

I swapped out the Torch attention function for Flash Attention 2 in the MusicGen project here like so:

if self.memory_efficient:
    p = self.dropout if self.training else 0
    if _efficient_attention_backend == 'torch':
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=attn_mask is not None, dropout_p=p)
        y = flash_attn_func(q, k, v, causal=attn_mask is not None, dropout_p=p)

        assert torch.allclose(x, y, atol=1e-5), "flash_attn_func and scaled_dot_product_attention are not equal"
    else:
        x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)

but the resulting tensors were not equal (my assertion halted execution) during inference like so:

import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write

model = MusicGen.get_pretrained('melody')
model.set_generation_params(duration=8)  # generate 8 seconds.
descriptions = ['happy rock', 'energetic EDM', 'sad jazz']
wav = model.generate(descriptions)  # generates 3 samples.

for idx, one_wav in enumerate(wav):
    # Will save under {idx}.wav, with loudness normalization at -14 db LUFS.
    audio_write(f'{idx}', one_wav.cpu(), model.sample_rate, strategy="loudness", loudness_compressor=True)

Additionally, even setting aside the significant divergence above (which affects the model's output), I also run into the following error during inference:

Number of heads in key/value must divide number of heads in query

This happens because num_heads stays at 1 for the query while the key's num_heads scales with the number of tokens in the context window. However, I assume this will be resolved by the forthcoming inference optimizations mentioned here.
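For reference, here is a minimal sketch of the layout flash_attn_func expects (shapes are illustrative, not taken from MusicGen). If q/k/v arrive in the (batch, heads, seqlen, headdim) layout that scaled_dot_product_attention uses, the growing KV length can end up in the heads dimension and trigger exactly this error:

import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seqlen, nheads, headdim), and the number of
# query heads must be a multiple of the number of key/value heads (MQA/GQA).
batch, seqlen_q, seqlen_k, nheads, headdim = 2, 1, 128, 16, 64
q = torch.randn(batch, seqlen_q, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(batch, seqlen_k, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(batch, seqlen_k, nheads, headdim, dtype=torch.float16, device="cuda")
out = flash_attn_func(q, k, v, causal=True)  # (batch, seqlen_q, nheads, headdim)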

zaptrem · Jul 26 '23 07:07

The right comparison is (FlashAttention in fp16/bf16 - standard attention in fp32) vs (standard attention in fp16/bf16 - standard attention in fp32).

What are those two differences in your use case?
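Something like the following rough sketch is what is being suggested (shapes and the fp32 SDPA reference are illustrative, not the MusicGen setup):

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), the layout flash_attn_func expects
q32 = torch.randn(2, 128, 16, 64, device="cuda")
k32, v32 = torch.randn_like(q32), torch.randn_like(q32)
q16, k16, v16 = q32.half(), k32.half(), v32.half()

def standard_attention(q, k, v):
    # scaled_dot_product_attention wants (batch, nheads, seqlen, headdim)
    return F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)

ref_fp32 = standard_attention(q32, k32, v32)  # standard attention in fp32 as the reference
err_standard = (standard_attention(q16, k16, v16).float() - ref_fp32).abs().max()
err_flash = (flash_attn_func(q16, k16, v16, causal=True).float() - ref_fp32).abs().max()
print(err_flash, err_standard)  # the two errors should be of similar magnitude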

tridao · Jul 26 '23 07:07

Sorry, I'm not quite sure what you're hinting at. I'm comparing FlashAttention 2 in FP16 vs Torch's memory-efficient SDP attention in FP16.

zaptrem · Jul 26 '23 21:07

I think for scaled_dot_product_attention you need to transpose the second and the third axes of qkv. It requires a different layout.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

query = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
key = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")
value = torch.rand(1, 8, 16, 32, dtype=torch.float16, device="cuda")

is_causal = True

with torch.no_grad():
    with torch.backends.cuda.sdp_kernel(enable_math=False):
        ref = F.scaled_dot_product_attention(torch.permute(query, [0, 2, 1, 3]),
                                             torch.permute(key, [0, 2, 1, 3]),
                                             torch.permute(value, [0, 2, 1, 3]),
                                             is_causal=is_causal)
        ref = torch.permute(ref, [0, 2, 1, 3])

    out = flash_attn_func(query, key, value, causal=is_causal)


print(torch.max(torch.abs(out - ref)), torch.mean(torch.abs(out - ref)))

masahi · Jul 27 '23 00:07

@tridao I've checked the differences you suggest, and they are of the same order as you expected.

Still, could you explain where the difference between FlashAttention in fp16 and standard attention in fp16 comes from? Cheers

rems75 · Oct 12 '23 16:10

Floating point operations are not associative. Changing the order of the operations will change the output, up to numerical precision. Example

In [1]: import torch

In [2]: a = torch.randn(1024, dtype=torch.float16, device="cuda")

In [3]: out1 = a + 0.3 - 0.2

In [4]: out2 = a - 0.2 + 0.3

In [5]: (out1 - out2).abs().max()
Out[5]: tensor(0.0020, device='cuda:0', dtype=torch.float16)

tridao · Oct 12 '23 17:10

import torch
import torch.nn.functional as F

B = 1; T = 3; nh = 32; C = 64

# q = torch.arange(B * T *  nh * C).reshape([B, T, nh, C]).float(); 
# k = torch.arange(B * T *  nh * C).flip(dims=(-1,)).reshape([B, T, nh, C]).float(); 
# v= torch.arange(B * T *  nh * C).reshape([B, T, nh, C]).float() ** 1.5
q = torch.randn([B, T, nh, C])
k = torch.randn([B, T, nh, C])
v = torch.randn([B, T, nh, C])


out = F.scaled_dot_product_attention(
    q.permute([0, 2, 1, 3]).cuda(),
    k.permute([0, 2, 1, 3]).cuda(),
    v.permute([0, 2, 1, 3]).cuda(),
    dropout_p=0.0, is_causal=False,
).permute([0, 2, 1, 3])

import math
def func(q, k, v, use_causal=False):
    q, k, v = q.permute([0, 2, 1, 3]), k.permute([0, 2, 3, 1]), v.permute([0, 2, 1, 3])
    # attn = (q @ k.transpose([0, 1, 3, 2])) * (1.0 + math.sqrt(k.shape[-1]))
    attn = (q @ k) * (1.0 / math.sqrt(k.shape[-1]))
    if use_causal:
        # causal_mask is assumed to be defined elsewhere (not exercised here, since use_causal=False)
        attn = attn.masked_fill(causal_mask[:, :, :T, :T] == 0, float('-inf'))
    attn = F.softmax(attn, dim=-1)
    out = attn @ v # (B, nh, T, T) x (B, nh, T, c) -> (B, nh, T, c)
    out = out.permute([0, 2, 1, 3]) # (B, T, nh, c)
    return out

out2 = func(q.cuda(), k.cuda(), v.cuda(), use_causal=False)

print((out2 - out).abs().mean()) # about 0.3

I use torch.float32 in a unit test and find the difference between normal attention and flash attention is about 0.3. I think the precision should be sufficient, but the difference is large. Could someone tell me why?

kaixinbear · Jan 03 '24 12:01

With torch.float32, F.scaled_dot_product_attention does not call FlashAttention (which is only implemented for fp16 and bf16). You can ask in the PyTorch GitHub.
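One way to see this is to restrict SDPA to the flash backend with the same sdp_kernel context manager used earlier in the thread; a rough sketch, assuming a CUDA build of PyTorch where that backend is available:

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64, device="cuda")  # (batch, nheads, seqlen, headdim), fp32
k, v = torch.randn_like(q), torch.randn_like(q)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    try:
        F.scaled_dot_product_attention(q, k, v)  # fp32 input: flash backend is not eligible
    except RuntimeError as err:
        print("fp32 rejected:", err)
    out = F.scaled_dot_product_attention(q.half(), k.half(), v.half())  # fp16 dispatches to FlashAttention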

tridao · Jan 03 '24 18:01

Should I be worried about the difference?

chenhuiapp · May 05 '24 20:05