memory-efficient-attention
Improve performance via batched-matmul and fused multiplies
Many thanks for providing this reference implementation.
I tried integrating this into stable-diffusion / diffusers. A fix was required to make it work on Mac (PyTorch MPS backend):
https://github.com/Birch-san/diffusers/pull/1/commits/04372140a25d7f53549175f1f196599c3e9bf3a5
Knowing that computing attention via `baddbmm()` + `bmm()` can outperform `einsum` by 18%, I tried to rewrite the algorithm to use those.
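To illustrate the idea (not the PR's exact code): `baddbmm` fuses the softmax scale (and optionally a bias) into the `q @ k^T` matmul, and `bmm` applies the attention weights to `v`. A minimal sketch comparing the two formulations on 3D `[batch * num_heads, tokens, channels_per_head]` tensors:

```python
import torch

def attn_einsum(q, k, v, scale):
    # einsum formulation: score matmul, scale, softmax, weighted sum
    scores = torch.einsum('bqc,bkc->bqk', q, k) * scale
    weights = scores.softmax(dim=-1)
    return torch.einsum('bqk,bkc->bqc', weights, v)

def attn_bmm(q, k, v, scale):
    # baddbmm with beta=0 ignores the "bias" input tensor entirely,
    # while alpha fuses the 1/sqrt(d) scale into the matmul itself
    scores = torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
        q,
        k.transpose(1, 2),
        beta=0,
        alpha=scale,
    )
    weights = scores.softmax(dim=-1)
    return torch.bmm(weights, v)
```

Both compute the same attention output; the `baddbmm`/`bmm` path avoids einsum's dispatch overhead and the separate scale multiply.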
I compared the speed of my optimized version against the implementation in this repository. This result is for "everything fits in one chunk" performance (i.e. chunk size = max token length). I was unable to compare chunked performance: although I got chunking working in my version, I wasn't able to get it working in this repository's version (it returned some unexpectedly-shaped tensors).
Compared to the implementation in this repository, my optimized version achieves a 2.78x speedup in the time taken to generate a 512x512 image with stable-diffusion v2.1-base (i.e. 4096 vision tokens, 5 attention heads, and a batch size of 2 due to CFG).
here's my optimized implementation:
https://github.com/Birch-san/diffusers/pull/1
Batched matmuls require a 3D tensor, i.e. `[batch * num_heads, tokens, channels_per_head]`. Code that currently integrates against this repository's `[batch, q_length, num_heads, qk_depth_per_head]` format can migrate those tensors to the `[batch * num_heads, q_length, channels_per_head]` format favoured by my implementation like so:
```python
query = query.transpose(1,2).flatten(end_dim=1)
key = key.transpose(1,2).flatten(end_dim=1)
value = value.transpose(1,2).flatten(end_dim=1)
```
The result that's returned remains in `[batch * num_heads, q_length, qk_depth_per_head]` format, and can be restored to `[batch, q_length, num_heads, qk_depth_per_head]` format like so:

```python
result.unflatten(0, (-1, attn.heads)).transpose(1,2)
```
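For concreteness, here is a round-trip sketch of the two conversions above; the shapes are illustrative, and a plain `heads` variable stands in for `attn.heads`:

```python
import torch

# repo format: [batch, q_length, num_heads, qk_depth_per_head]
batch, q_length, heads, ch = 2, 16, 5, 64
x = torch.randn(batch, q_length, heads, ch)

# to bmm-friendly format: [batch * num_heads, q_length, channels_per_head]
b3d = x.transpose(1, 2).flatten(end_dim=1)

# back to repo format: [batch, q_length, num_heads, qk_depth_per_head]
restored = b3d.unflatten(0, (-1, heads)).transpose(1, 2)
assert torch.equal(restored, x)
```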
I think a further speedup is possible too, by working out when chunking is not needed: we can compute whether unchunked attention would fit into memory, and prefer unchunked attention as a fast-path where possible. This will be useful in a Unet, which runs attention at various resolutions.
EDIT:
I have now added fast-paths for:

- skipping kv-chunking when `kv_chunk_size >= k_tokens`
  - this turns the algorithm into "attention slicing"
- skipping q-chunking when `q_chunk_size >= q_tokens`
- skipping all chunking when both `kv_chunk_size >= k_tokens` and `q_chunk_size >= q_tokens`
- skipping all chunking when the `q @ k` matmul requires fewer bytes than a user-provided threshold
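The byte-threshold fast-path can be sketched as follows; the helper name and threshold parameter are hypothetical, not the PR's actual API:

```python
import torch

def fits_unchunked(q: torch.Tensor, k: torch.Tensor, bytes_limit: int) -> bool:
    # Hypothetical helper: estimate the size of the attention-scores tensor
    # produced by q @ k^T, of shape [batch * num_heads, q_tokens, k_tokens],
    # and prefer unchunked attention when it fits under a user-provided budget.
    bytes_per_element = q.element_size()
    scores_bytes = q.shape[0] * q.shape[1] * k.shape[1] * bytes_per_element
    return scores_bytes <= bytes_limit
```

A Unet would call this per attention layer, so low-resolution layers (few tokens) take the unchunked fast-path while high-resolution layers fall back to chunking.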