xformers
Output from memory_efficient_attention is not exactly the same as the PyTorch equivalent implementation
❓ Questions and Help
Hi, I compared memory_efficient_attention against the PyTorch equivalent implementation from the docs and found that the outputs are not exactly the same. The code:
def attention_e(q, k, v):
    # q, k, v: (batch, seq_len, n_head, head_dim)
    scale = 1 / q.shape[-1] ** 0.5
    q = q * scale
    # move heads next to batch for the matmuls: (batch, n_head, seq_len, head_dim)
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    a = q @ k.transpose(-2, -1)
    a = a.softmax(-1)
    # back to (batch, seq_len, n_head, head_dim)
    return (a @ v).transpose(1, 2)
import torch
from xformers.ops import memory_efficient_attention

b, n_head, head_dim, seq_len = 4, 8, 16, 128
q = torch.rand(b, seq_len, n_head, head_dim).to('cuda').half()
k = torch.rand(b, seq_len, n_head, head_dim).to('cuda').half()
v = torch.rand(b, seq_len, n_head, head_dim).to('cuda').half()

o1 = memory_efficient_attention(q, k, v)
o2 = attention_e(q, k, v)
print(torch.norm(o1 - o2))
and the output was:
tensor(0.0183, device='cuda:0', dtype=torch.float16)
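For scale: o1 - o2 has 4 * 128 * 8 * 16 = 65,536 elements, so that norm works out to an RMS difference of roughly 0.0183 / 256 ≈ 7e-5 per element, which is well below float16's resolution near 1.0 (machine epsilon 2**-10). A small self-contained check of that arithmetic:

```python
import torch

# float16 machine epsilon: the gap between 1.0 and the next representable value
eps = torch.finfo(torch.float16).eps
print(eps)  # 0.0009765625

# the reported norm spread (RMS) over all 4 * 128 * 8 * 16 output elements
rms = 0.0183 / (4 * 128 * 8 * 16) ** 0.5
print(rms < eps)  # True: per-element difference is below fp16 resolution
```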
Is this expected behavior, or did I do something wrong? Thanks!
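Edit: to rule out a bug in my reference implementation itself, here is a cross-check of attention_e against PyTorch's built-in F.scaled_dot_product_attention (assumes PyTorch >= 2.0; run on CPU in float32 so only the math is compared, not the fp16 kernels):

```python
import torch
import torch.nn.functional as F

def attention_e(q, k, v):
    # same reference implementation as above, without fp16/CUDA
    scale = 1 / q.shape[-1] ** 0.5
    q = q * scale
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    a = q @ k.transpose(-2, -1)
    a = a.softmax(-1)
    return (a @ v).transpose(1, 2)

torch.manual_seed(0)
# small float32 tensors on CPU: (batch, seq_len, n_head, head_dim)
q, k, v = (torch.rand(2, 32, 4, 8) for _ in range(3))

# F.scaled_dot_product_attention expects (batch, n_head, seq_len, head_dim)
ref = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)

print(torch.allclose(attention_e(q, k, v), ref, atol=1e-5))  # True
```

The two agree to float32 tolerance, so the discrepancy above is down to fp16 arithmetic and summation order, not the reference math.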