stable-diffusion
use torch's built-in Multihead attention class for self-attention
CompVis stable-diffusion uses lucidrains' multi-head attention implementation from perceiver_pytorch.

there's actually a built-in MultiheadAttention class in torch that we could be using instead. it even has a fast-path for self-attention, which delegates to torch._native_multi_head_attention.
here I've implemented PyTorch's MultiheadAttention for the Unet's self-attention layers, which are perhaps the main perf bottleneck in stable-diffusion: they do giant matrix multiplies, since a 512x512 image can reach token counts as high as 4096 (though admittedly you could reduce this with ToMe token merging).
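to illustrate the approach (a minimal sketch, not the actual Unet code from this PR; the module name and dims are illustrative), a self-attention layer built on torch's class looks roughly like this:

```python
import torch
from torch import nn

class SelfAttention(nn.Module):
    """Sketch of self-attention via torch's built-in MultiheadAttention.
    Illustrative only; not the actual stable-diffusion Unet wiring."""
    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        # batch_first=True so inputs are (batch, tokens, channels)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # passing the same tensor as query/key/value is what makes the
        # self-attention fast-path eligible (along with eval mode, no
        # grad, batch_first, etc.)
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

# e.g. 4096 tokens of 320 channels, like the Unet's highest-res self-attn block
x = torch.randn(1, 4096, 320)
y = SelfAttention(320)(x)
print(y.shape)  # torch.Size([1, 4096, 320])
```

note that need_weights=False matters for speed: returning averaged attention weights forces a slower codepath.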
on MPS, latest PyTorch nightly (1.14.0.dev20221103), running 8 Heun steps:

- original lucidrains einsum: 10.3 secs
- MultiheadAttention "slow" path: 12.0 secs
- MultiheadAttention "fast" path (torch._native_multi_head_attention): 37.1 secs
no speed improvement on the MPS backend; einsum is still ~16% faster than the "slow" path.
but a CUDA user should try this out and see if there's any improvement! well, maybe they're too busy enjoying Flash Attention.
I wonder whether it gives any speedup on CPU? those benchmarks would take a while to run though.
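a rough CPU micro-benchmark could be sketched like this (sizes scaled down from the Unet's real shapes so it finishes quickly; the bare einsum here skips projections and head-splitting, so it only gauges the matmul/softmax cost, not a full apples-to-apples comparison):

```python
import timeit
import torch
from torch import nn

# illustrative sizes, much smaller than the real Unet workload
dim, heads, tokens, runs = 320, 8, 1024, 5
x = torch.randn(1, tokens, dim)
mha = nn.MultiheadAttention(dim, heads, batch_first=True).eval()

def einsum_attn(q, k, v):
    # bare single-head softmax(QK^T / sqrt(d)) V, no projections
    sim = torch.einsum('bid,bjd->bij', q, k) * (dim ** -0.5)
    return torch.einsum('bij,bjd->bid', sim.softmax(dim=-1), v)

with torch.inference_mode():
    t_mha = timeit.timeit(lambda: mha(x, x, x, need_weights=False), number=runs)
    t_ein = timeit.timeit(lambda: einsum_attn(x, x, x), number=runs)
print(f'MultiheadAttention: {t_mha:.3f}s, einsum: {t_ein:.3f}s over {runs} runs')
```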
ordinarily the "fast" path (torch._native_multi_head_attention) would not be accessible on MPS, but I forced PyTorch to try, by modifying torch.nn.modules.activation like so:
![image](https://user-images.githubusercontent.com/6141784/199858792-7562f804-9bb4-456d-a48f-06bbfef2c2ae.png)
but it's about 3x slower so best not to bother.
and yes, MultiheadAttention outputs the same image as the original einsum. so the implementation seems correct.
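one way to sanity-check that equivalence numerically (a hedged sketch with toy dimensions, not the actual test used here): copy MultiheadAttention's projection weights into a manual einsum attention and compare outputs.

```python
import torch
from torch import nn
from torch.nn.functional import linear

torch.manual_seed(0)
dim, heads, tokens = 64, 8, 16
head_dim = dim // heads
mha = nn.MultiheadAttention(dim, heads, batch_first=True).eval()
x = torch.randn(1, tokens, dim)

with torch.no_grad():
    # reuse mha's packed projection weights for the manual computation
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3)
    q, k, v = linear(x, w_q, b_q), linear(x, w_k, b_k), linear(x, w_v, b_v)

    # split heads: (batch, heads, tokens, head_dim)
    def split(t):
        return t.view(1, tokens, heads, head_dim).transpose(1, 2)
    q, k, v = map(split, (q, k, v))

    sim = torch.einsum('bhid,bhjd->bhij', q, k) * head_dim ** -0.5
    out = torch.einsum('bhij,bhjd->bhid', sim.softmax(dim=-1), v)
    out = mha.out_proj(out.transpose(1, 2).reshape(1, tokens, dim))

    ref, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(out, ref, atol=1e-5))  # True
```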