`ScaledDotProduct` with attention mask returns a different result than standard attention
Great work on this project! I'm doing some benchmarking with key padding masking, but I'm getting a different answer from `xformers.components.attention.ScaledDotProduct` than from standard attention. Could you help clarify how this component should be used?
I suspect part of the issue could be ambiguity in what the shapes should be for `q`, `k`, `v`, and `att_mask` (see the commented-out lines in my `ScaledDotProduct` implementation below):
```python
def xformers_scaled_dot_product_attention(self, x, mask=None):
    b, l, _, h = *x.shape, self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    # q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))
    # xformers scaled dot product attention fn applies the scaling by dim_head ** -0.5
    if mask.ndim == 2:
        mask = repeat(mask, "b l -> b l l_prime", l_prime=l)
        # mask = repeat(mask, "b l -> b h l l_prime", h=h, l_prime=l)
    out = self.xformers_scaled_dot_product_fn(q, k, v, att_mask=mask)
    # out = rearrange(out, "b h l d -> b l (h d)", h=h)
    return self.to_out(out)
```
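For reference, here is a minimal sketch of how I would expand the `(batch, seq)` key-padding mask if `ScaledDotProduct` expects a boolean `att_mask` broadcastable to `(batch, seq_q, seq_k)` with `True` meaning "attend" -- that expectation is just my guess, so please correct me if it's wrong:

```python
# Sketch only: expand a (batch, seq) key-padding mask (True = keep token) to
# (batch, seq_q, seq_k), broadcasting along the *key* axis so every query row
# sees the same set of valid keys, matching standard_attention further below.
import torch
from einops import repeat

def expand_key_padding_mask(mask: torch.Tensor, additive: bool = False) -> torch.Tensor:
    b, l = mask.shape
    full = repeat(mask, "b j -> b i j", i=l)  # (b, seq_q, seq_k), bool
    if not additive:
        return full
    # Additive variant: 0.0 where attended, -inf where masked out.
    bias = torch.zeros(b, l, l, dtype=torch.float32, device=mask.device)
    return bias.masked_fill(~full, float("-inf"))
```

If I'm reading the einops patterns right, my `repeat(mask, "b l -> b l l_prime", l_prime=l)` above broadcasts along the query axis rather than the key axis, so that's one of the things I'd like to confirm.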
FWIW, my `memory_efficient_attention` version does get the same result as standard attention:
```python
def xformers_memory_efficient_attention(self, x, mask=None):
    dtype, device = x.dtype, x.device
    b, l, _, h = *x.shape, self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b l h d", h=h), (q, k, v))
    # Create attn_bias from padding mask
    if mask is not None:
        attn_bias = torch.zeros_like(mask, dtype=dtype)
        attn_bias = attn_bias.masked_fill(~mask, float('-inf'))
        attn_bias = rearrange(attn_bias, "b l -> b () () l")  # Shape: (batch_size, 1, 1, seq_len)
        attn_bias = attn_bias.repeat(1, h, l, 1)
        attn_bias = attn_bias.to(device)
    else:
        attn_bias = None
    out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    out = rearrange(out, "b l h d -> b l (h d)")
    return self.to_out(out)
```
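(Side note: as an independent reference for the key-padding behaviour, a cross-check against PyTorch 2.x's built-in `F.scaled_dot_product_attention` could look roughly like the sketch below -- boolean `attn_mask`, `True` = attend, broadcast over heads and queries. Shapes are illustrative.)

```python
# Illustrative cross-check with PyTorch's built-in SDPA (PyTorch >= 2.0).
# The boolean attn_mask broadcasts from (b, 1, 1, l) to (b, h, l, l);
# True means the position participates in attention.
import torch
import torch.nn.functional as F

b, h, l, d = 2, 8, 128, 64
q, k, v = (torch.randn(b, h, l, d) for _ in range(3))
key_padding = torch.ones(b, l, dtype=torch.bool)
key_padding[:, 100:] = False               # pretend the tail is padding

attn_mask = key_padding[:, None, None, :]  # (b, 1, 1, l)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)  # (b, h, l, d)
```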
```python
def standard_attention(self, x, mask=None):
    h = self.heads
    q, k, v = self.to_qkv(x).chunk(3, dim=-1)
    q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))
    q = q * self.scale
    sim = einsum("b h i d, b h j d -> b h i j", q, k)
    mask_value = -torch.finfo(sim.dtype).max
    if exists(mask):
        mask = rearrange(mask, "b j -> b () () j")
        sim = sim.masked_fill(~mask, mask_value)
    attn = sim.softmax(dim=-1)
    attn = self.dropout(attn)
    out = einsum("b h i j, b h j d -> b h i d", attn, v)
    out = rearrange(out, "b h l d -> b l (h d)", h=h)
    return self.to_out(out)
```
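For completeness, the kind of comparison I'm doing looks roughly like this (illustrative shapes; `Attention` is a stand-in name for my module holding the three methods above, put in eval mode so dropout is a no-op):

```python
# Illustrative comparison harness (shapes and the Attention class name are
# placeholders for my actual module, which exposes the three methods above).
import torch

torch.manual_seed(0)
model = Attention(dim=512, heads=8).cuda().eval()          # eval(): dropout off
x = torch.randn(4, 128, 512, device="cuda")                 # (batch, seq, dim)
mask = torch.ones(4, 128, dtype=torch.bool, device="cuda")
mask[:, 100:] = False                                        # tail of each sequence is padding

with torch.no_grad():
    out_std = model.standard_attention(x, mask=mask)
    out_mem = model.xformers_memory_efficient_attention(x, mask=mask)
    out_sdp = model.xformers_scaled_dot_product_attention(x, mask=mask)

print(torch.allclose(out_std, out_mem, atol=1e-4))  # True -- these two agree
print(torch.allclose(out_std, out_sdp, atol=1e-4))  # False -- hence this issue
```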
The results I get are fastest with `ScaledDotProduct`, so I'd like to see if I can get that working:
```
Standard Multihead Attention           - Time: 0.0953s, Memory: 77483.86MB
xFormers Memory Efficient Attention    - Time: 0.0519s, Memory: 65258.86MB
xFormers Scaled Dot Product Attention  - Time: 0.0295s, Memory: 51882.86MB
```
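(For reference, the kind of measurement harness I have in mind is sketched below -- illustrative only, not the exact script behind the numbers above.)

```python
# Illustrative timing / peak-memory harness (not the exact benchmark script).
import time
import torch

def benchmark(fn, x, mask, iters=50, warmup=10):
    with torch.no_grad():
        for _ in range(warmup):
            fn(x, mask=mask)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        for _ in range(iters):
            fn(x, mask=mask)
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / iters       # seconds per call
    peak_mb = torch.cuda.max_memory_allocated() / 2**20   # MiB at the peak
    return elapsed, peak_mb
```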
Thanks for the help and congrats again on the great work! Cc @danthe3rd