memory-efficient-attention

Improve performance via batched-matmul and fused multiplies

Open • Birch-san opened this issue 1 year ago • 10 comments

Many thanks for providing this reference implementation.

I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
https://github.com/Birch-san/diffusers/pull/1/commits/04372140a25d7f53549175f1f196599c3e9bf3a5

Since computing attention via baddbmm() + bmm() can outperform einsum by 18%, I tried to rewrite the algorithm to use those.
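For reference, here is a minimal sketch of the two formulations on 3D [batch * num_heads, tokens, channels_per_head] tensors. The function names are hypothetical and this is not the code from the linked PR. `torch.baddbmm` computes `beta * input + alpha * (batch1 @ batch2)`. With `beta=0` the uninitialised `input` is ignored, so the softmax scale is fused into the matmul instead of being applied as a separate multiply.

```python
import torch


def attention_einsum(q, k, v, scale):
    # scores: [batch * num_heads, q_tokens, k_tokens]
    scores = torch.einsum("bqc,bkc->bqk", q, k) * scale
    probs = scores.softmax(dim=-1)
    return torch.einsum("bqk,bkc->bqc", probs, v)


def attention_bmm(q, k, v, scale):
    # baddbmm: beta * input + alpha * (q @ k^T); beta=0 means the empty input is
    # ignored (its contents are never read), and alpha fuses the scale into the matmul
    empty = torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
    scores = torch.baddbmm(empty, q, k.transpose(1, 2), beta=0, alpha=scale)
    probs = scores.softmax(dim=-1)
    return torch.bmm(probs, v)
```

Both functions should agree to within floating-point tolerance. Any speed difference between them depends on backend and shapes.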

I compared the speed of my optimized version against the implementation in this repository. This result is for "everything fits in one chunk" performance (i.e. chunk size = max token length). I was unable to compare chunked performance: although I got chunking working in my version, I wasn't able to get it working in the version in this repository (it returned tensors of unexpected shapes).

Compared to the implementation in this repository, my optimized version achieves a 2.78x speedup in the time taken to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, batch size of 2 due to classifier-free guidance (CFG)).

Here's my optimized implementation:
https://github.com/Birch-san/diffusers/pull/1

Batched matmuls require 3D tensors, i.e. [batch * num_heads, tokens, channels_per_head].

Code that currently integrates against this repository's [batch, q_length, num_heads, qk_depth_per_head] format can migrate those tensors to the [batch * num_heads, q_length, channels_per_head] format favoured by my implementation like so:

```python
query = query.transpose(1,2).flatten(end_dim=1)
key = key.transpose(1,2).flatten(end_dim=1)
value = value.transpose(1,2).flatten(end_dim=1)
```

The returned result remains in [batch * num_heads, q_length, qk_depth_per_head] format, and can be restored to [batch, q_length, num_heads, qk_depth_per_head] format like so:

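```python
# attn.heads is the number of attention heads (the `heads` attribute of the diffusers attention module)
result = result.unflatten(0, (-1, attn.heads)).transpose(1,2)
```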
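The two reshapes are inverses of each other. Here is a small sketch that wraps them in helpers and checks the round trip. The helper names are hypothetical. It also checks that row `b * num_heads + h` of the 3D tensor holds head `h` of batch item `b`.

```python
import torch


def to_bmm_layout(x):
    # [batch, tokens, num_heads, d] -> [batch, num_heads, tokens, d] -> [batch * num_heads, tokens, d]
    return x.transpose(1, 2).flatten(end_dim=1)


def from_bmm_layout(x, num_heads):
    # [batch * num_heads, tokens, d] -> [batch, num_heads, tokens, d] -> [batch, tokens, num_heads, d]
    return x.unflatten(0, (-1, num_heads)).transpose(1, 2)
```

Because `transpose` returns a view, `flatten` may need to copy. The layout change is lossless, but it is not free.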

I think a further speedup is possible by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast path where possible. This will be useful in a UNet, which runs attention at various resolutions.
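Here is a back-of-the-envelope sketch of that check. It assumes the [batch * num_heads, q_tokens, k_tokens] score matrix dominates memory, and it ignores the softmax output and other temporaries, which need extra space on top. The function names, the fp16 element size, and any threshold value are hypothetical.

```python
def attention_scores_bytes(batch_x_heads: int, q_tokens: int, k_tokens: int, bytes_per_element: int) -> int:
    """Bytes needed for the [batch * num_heads, q_tokens, k_tokens] score matrix of q @ k.T."""
    return batch_x_heads * q_tokens * k_tokens * bytes_per_element


def can_skip_chunking(batch_x_heads: int, q_tokens: int, k_tokens: int,
                      bytes_per_element: int, threshold_bytes: int) -> bool:
    """Fast-path check: run unchunked attention if the score matrix is below a user-chosen budget."""
    return attention_scores_bytes(batch_x_heads, q_tokens, k_tokens, bytes_per_element) < threshold_bytes


# stable-diffusion v2.1-base self-attention at 512x512, assuming fp16:
# batch 2 (CFG) * 5 heads, 4096 query and key tokens
print(attention_scores_bytes(2 * 5, 4096, 4096, 2) / 2**20)  # 320.0 (MiB)
```

Lower-resolution UNet layers (and cross-attention, where the keys are the much shorter text sequence) produce far smaller score matrices. So the same threshold can choose chunking at one layer and the unchunked fast path at another.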

EDIT:
I have now added fast-paths for:

  • skipping kv-chunking when kv_chunk_size >= k_tokens
    • this turns the algorithm into "attention slicing"
  • skipping q-chunking when q_chunk_size >= q_tokens
  • skipping all chunking when kv_chunk_size >= k_tokens and q_chunk_size >= q_tokens
  • skipping all chunking when the q @ k.T matmul requires fewer bytes than a user-provided threshold

Birch-san, Dec 27 '22 23:12