
`ScaledDotProduct` with attention mask returns a different result than standard attention

amyxlu opened this issue 5 months ago · 3 comments

Great work on this project! I'm doing some benchmarking with key padding masking, but I'm getting a different answer from xformers.components.attention.ScaledDotProduct than from standard attention. Could you help clarify how this component should be used?

I suspect part of the issue could be ambiguity in what the shape should be for q, k, v, and att_mask (see commented-out lines in the implementation below):

ScaledDotProduct

    def xformers_scaled_dot_product_attention(self, x, mask=None):
        b, l, _, h = *x.shape, self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))

        # xformers' ScaledDotProduct applies the dim_head ** -0.5 scaling internally

        if mask is not None and mask.ndim == 2:
            # Expand the (b, l) key padding mask to (b, l, l) -- unclear if this is the expected layout
            mask = repeat(mask, "b l -> b l l_prime", l_prime=l)
            # mask = repeat(mask, "b l -> b h l l_prime", h=h, l_prime=l)

        out = self.xformers_scaled_dot_product_fn(q, k, v, att_mask=mask)
        # out = rearrange(out, "b h l d -> b l (h d)", h=h)
        return self.to_out(out)
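
To make the ambiguity concrete: depending on which axis the (b, l) padding mask is repeated along, the padding lands on either the query rows or the key columns. Here is a small standalone check of the two expansions (plain torch + einops, nothing xformers-specific); the repeat I'm currently using puts the padding on the query axis, whereas the standard attention reference below masks the key axis, and I'm not sure which layout `ScaledDotProduct` expects for `att_mask`:

    import torch
    from einops import repeat

    # Toy key padding mask: batch of 1, seq_len 4, last position padded (False = masked out)
    mask = torch.tensor([[True, True, True, False]])
    b, l = mask.shape

    # Expansion currently used above: out[b, i, j] = mask[b, i]
    # -> padding information ends up on the *query* axis (whole rows masked)
    query_axis = repeat(mask, "b l -> b l l_prime", l_prime=l)

    # Alternative expansion: out[b, i, j] = mask[b, j]
    # -> padding information ends up on the *key* axis (whole columns masked),
    #    matching the `rearrange(mask, "b j -> b () () j")` broadcast in standard_attention below
    key_axis = repeat(mask, "b l -> b l_prime l", l_prime=l)

    print(query_axis[0].int())  # rows 0-2 all ones, row 3 all zeros
    print(key_axis[0].int())    # every row is [1, 1, 1, 0]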

FWIW, my memory_efficient_attention does get the same result as standard attention:

 def xformers_memory_efficient_attention(self, x, mask=None):
      dtype, device = x.dtype, x.device
      b, l, _, h = *x.shape, self.heads
      q, k, v = self.to_qkv(x).chunk(3, dim=-1)
      q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b l h d", h=h), (q, k, v))

      # Create attn_bias from padding mask
      if mask is not None:
          attn_bias = torch.zeros_like(mask, dtype=dtype)
          attn_bias = attn_bias.masked_fill(~mask, float('-inf'))
          attn_bias = rearrange(attn_bias, "b l -> b () () l")  # Shape: (batch_size, 1, 1, seq_len)
          attn_bias = attn_bias.repeat(1, h, l, 1)  # Shape: (batch_size, heads, seq_len, seq_len), -inf over padded key columns
          attn_bias = attn_bias.to(device)
      else:
          attn_bias = None

      out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
      out = rearrange(out, "b l h d -> b l (h d)")
      return self.to_out(out)

  def standard_attention(self, x, mask=None):
      h = self.heads
      q, k, v = self.to_qkv(x).chunk(3, dim=-1)
      q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))

      q = q * self.scale

      sim = einsum("b h i d, b h j d -> b h i j", q, k)
      mask_value = -torch.finfo(sim.dtype).max

      if exists(mask):
          mask = rearrange(mask, "b j -> b () () j")
          sim = sim.masked_fill(~mask, mask_value)

      attn = sim.softmax(dim=-1)
      attn = self.dropout(attn)

      out = einsum("b h i j, b h j d -> b h i d", attn, v)
      out = rearrange(out, "b h l d -> b l (h d)", h=h)
      return self.to_out(out)
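
By "same result" I mean an elementwise comparison along these lines (a minimal sketch: `attn` stands for an instance of the module defining the methods above, dropout is assumed disabled / the module in eval mode, and the sizes and fp16 tolerances are arbitrary):

    import torch

    # `attn` is assumed to be an instance of the attention module shown above (hypothetical setup)
    torch.manual_seed(0)
    b, l, d = 4, 128, 512  # arbitrary; d must match the module's input dim
    x = torch.randn(b, l, d, device="cuda", dtype=torch.float16)
    mask = torch.ones(b, l, dtype=torch.bool, device="cuda")
    mask[:, -16:] = False  # mark the last 16 positions as padding

    ref = attn.standard_attention(x, mask=mask)
    out = attn.xformers_memory_efficient_attention(x, mask=mask)
    print(torch.allclose(ref, out, atol=1e-3, rtol=1e-3))  # matches for me; tolerances are loose for fp16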

ScaledDotProduct is the fastest of the three in my benchmarks, so I'd like to see if I can get it working:

Standard Multihead Attention - Time: 0.0953s, Memory: 77483.86MB
xFormers Memory Efficient Attention - Time: 0.0519s, Memory: 65258.86MB
xFormers Scaled Dot Product Attention - Time: 0.0295s, Memory: 51882.86MB
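
(A measurement loop along the following lines can be used to reproduce this kind of timing/peak-memory comparison; it's a rough sketch assuming CUDA, where `fn` stands for any of the three methods above bound to the same inputs, and the warmup/iteration counts are arbitrary.)

    import time
    import torch

    def benchmark(fn, warmup=3, iters=10):
        # Warm up kernels before timing
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        elapsed = (time.perf_counter() - start) / iters
        peak_mb = torch.cuda.max_memory_allocated() / 2**20
        return elapsed, peak_mb

    # e.g. benchmark(lambda: attn.xformers_memory_efficient_attention(x, mask=mask))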

Thanks for the help and congrats again on the great work! Cc @danthe3rd

amyxlu · Aug 27 '24