
_scaled_dot_product_efficient_attention: bug in the lse return value

Open · neonhuang opened this issue 10 months ago · 2 comments

https://github.com/feifeibear/long-context-attention/blob/0.6.0/yunchang/kernels/attention.py#L47 (torch version 2.3). Hello, in my experiments I found a problem with the following case: batch_size = 1, num_heads = 2, head_dim = 128, seq_len = 16.

The lse return value of torch.ops.aten._scaled_dot_product_efficient_attention has a bug. I suggest using _scaled_dot_product_flash_attention here instead, which returns the lse correctly.

import torch

batch_size = 1
num_heads = 2
head_dim = 128
seq_len = 16

dtype = torch.float16
device = 'cuda'

torch.manual_seed(42)
query = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
key = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)
value = torch.rand(batch_size, num_heads, seq_len, head_dim, dtype=dtype, device=device)

out1 = torch.nn.functional.scaled_dot_product_attention(query, key, value)

out2, lse2 = torch.ops.aten._scaled_dot_product_flash_attention(query, key, value)[:2]
print(f'lse2: {lse2}')

out3, lse3 = torch.ops.aten._scaled_dot_product_efficient_attention(query, key, value, attn_bias=None, compute_log_sumexp=True)[:2]
print(f'lse3: {lse3}')

print(f'Result: {torch.allclose(out1, out2, rtol=1e-3, atol=1e-3)}')
print(f'Result: {torch.allclose(out1, out3, rtol=1e-3, atol=1e-3)}')
print(f'Result: {torch.allclose(lse2, lse3, rtol=1e-3, atol=1e-3)}')

Output:

lse2: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459, 5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473],
              [5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015, 5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732]]], device='cuda:0')

lse3: tensor([[[5.4660, 5.7685, 5.4396, 5.6371, 5.3691, 5.4923, 5.4666, 5.4459, 5.5439, 5.7281, 5.8045, 5.6622, 5.7923, 5.6987, 5.5620, 5.5473, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf],
              [5.4989, 5.8029, 5.5886, 5.5052, 5.6427, 5.5984, 5.7117, 5.4015, 5.6134, 5.5992, 5.4512, 5.8386, 5.8852, 5.3351, 5.6285, 5.6732, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf, inf]]], device='cuda:0')

Result: True
Result: True

neonhuang · Feb 08 '25 08:02
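
For context on the inf values above: the memory-efficient SDP kernel pads the logsumexp along the sequence dimension (here from 16 up to 32) and fills the padding with inf, so any consumer that assumes lse has shape (batch, num_heads, seq_len) will read the padding. Below is a minimal workaround sketch that trims the padded lse; it assumes the (batch, heads, seq, dim) layout used in the repro and is not the fix adopted by the repository.

import torch

def efficient_attn_with_trimmed_lse(query, key, value, dropout_p=0.0, is_causal=False, scale=None):
    # query/key/value: (batch, heads, seq, dim), as in the repro above.
    out, lse = torch.ops.aten._scaled_dot_product_efficient_attention(
        query, key, value,
        attn_bias=None,
        compute_log_sumexp=True,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )[:2]
    # The memory-efficient kernel pads lse along the sequence dimension
    # (16 -> 32 in the repro) and fills the padding with inf; trim it back
    # to the true sequence length before using it downstream.
    lse = lse[..., :query.shape[-2]]
    return out, lse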

https://github.com/feifeibear/long-context-attention/blob/0.6.0/yunchang/kernels/attention.py#L47 can be changed to the following:

out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
    q.permute(0, 2, 1, 3),
    k.permute(0, 2, 1, 3),
    v.permute(0, 2, 1, 3),
    dropout_p=dropout_p,
    is_causal=causal,
    scale=softmax_scale,
)[:2]

neonhuang · Feb 08 '25 09:02
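
For readers following the suggested change, here is a hedged sketch of how it could sit inside a wrapper that takes q, k, v in (batch, seq, heads, dim) layout, which is what the permutes above imply. The function name and the return layout are illustrative assumptions, not the repository's actual API.

import torch

def flash_attn_forward_sketch(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    # q, k, v: (batch, seq, heads, dim); the flash kernel expects (batch, heads, seq, dim).
    out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
        q.permute(0, 2, 1, 3),
        k.permute(0, 2, 1, 3),
        v.permute(0, 2, 1, 3),
        dropout_p=dropout_p,
        is_causal=causal,
        scale=softmax_scale,
    )[:2]
    # Permute out back to the caller's (batch, seq, heads, dim) layout;
    # lse comes back as (batch, heads, seq) without the inf padding seen
    # with the memory-efficient kernel.
    return out.permute(0, 2, 1, 3), lse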

Thanks @neonhuang! Could you submit an MR? If the torch version is < 2.3, execute the code you pasted?

feifeibear · Feb 13 '25 09:02
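
A minimal sketch of the kind of version guard being asked about: prefer the flash kernel where it is available, and fall back to the memory-efficient kernel with the lse trimmed otherwise. The 2.3 cutoff, the helper name, and the (batch, heads, seq, dim) layout are assumptions for illustration, not what the repository eventually merged.

import torch
from packaging import version

_TORCH_GE_2_3 = version.parse(torch.__version__).release >= (2, 3)

def sdpa_with_lse(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False):
    # q, k, v: (batch, heads, seq, dim); the flash kernel additionally
    # requires fp16/bf16 inputs on CUDA.
    if _TORCH_GE_2_3:
        out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
            q, k, v, dropout_p=dropout_p, is_causal=causal, scale=softmax_scale
        )[:2]
    else:
        # Exact aten signatures vary across older releases, so this branch
        # may need adjusting for a given torch version.
        out, lse = torch.ops.aten._scaled_dot_product_efficient_attention(
            q, k, v, attn_bias=None, compute_log_sumexp=True,
            dropout_p=dropout_p, is_causal=causal, scale=softmax_scale
        )[:2]
        # Trim the inf padding shown in the repro above.
        lse = lse[..., :q.shape[-2]]
    return out, lse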