
use torch's built-in MultiheadAttention class for self-attention

Birch-san opened this issue 1 year ago • 3 comments

CompVis stable-diffusion uses lucidrains' multi-head attention implementation from perceiver_pytorch.

there's actually a built-in MultiheadAttention class in torch that we could be using.
it even has a fast-path for self-attention, which delegates to torch._native_multi_head_attention.

here I've implemented PyTorch MultiheadAttention for the Unet's self-attention layers, which are perhaps the main perf bottleneck in stable-diffusion (because they do giant matrix multiplies — a 512x512 image can reach token counts as high as 4096, though admittedly you could reduce this with ToMe token merging).
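(for the curious, here's a minimal sketch of the kind of swap involved; this is an illustrative wrapper, not the actual CompVis module or my branch. the key points are batch_first mode, so inputs match the Unet's flattened (batch, tokens, channels) feature maps, and passing the same tensor as query, key and value so the self-attention fast path is eligible:)

```python
import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    # hypothetical wrapper, not the actual CompVis attention module
    def __init__(self, dim: int, heads: int = 8):
        super().__init__()
        # batch_first=True: inputs are (batch, tokens, channels),
        # matching the flattened (b, h*w, c) feature maps in the Unet
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # query = key = value is what makes torch consider its
        # self-attention fast path (torch._native_multi_head_attention)
        out, _ = self.attn(x, x, x, need_weights=False)
        return out

x = torch.randn(2, 4096, 320)  # e.g. 512x512 image -> 64x64 latent -> 4096 tokens
print(SelfAttentionBlock(320)(x).shape)  # torch.Size([2, 4096, 320])
```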

on MPS, latest PyTorch nightly 1.14.0.dev20221103, running 8 Heun steps…

- original lucidrains einsum: 10.3 secs
- MultiheadAttention "slow" path: 12.0 secs
- MultiheadAttention "fast" path (torch._native_multi_head_attention): 37.1 secs

no speed improvement on the MPS backend; einsum is still ~16% faster than the slow path (10.3 secs vs 12.0 secs).
but a CUDA user should try this out and see if there's any improvement! well, maybe they're too busy enjoying Flash Attention.
I wonder whether it gives any speedup on CPU? those benchmarks would take a while to run though.
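(if anyone wants to try, a rough timing harness along these lines would do; illustrative only, since my numbers above came from full Heun sampling runs rather than an isolated microbenchmark:)

```python
import time
import torch

@torch.no_grad()
def bench(module: torch.nn.Module, x: torch.Tensor, iters: int = 20) -> float:
    # time repeated forward passes; synchronize so queued GPU work is counted
    module(x)  # warm-up
    if x.device.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        module(x)
    if x.device.type == "cuda":
        torch.cuda.synchronize()
    return time.perf_counter() - t0
```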

ordinarily the "fast" path (torch._native_multi_head_attention) is not reachable on MPS, but I forced PyTorch to try it by modifying torch/nn/modules/activation.py like so:

[screenshot: modified device check in MultiheadAttention.forward, torch/nn/modules/activation.py]
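(since the screenshot may not survive: reconstructing from memory of the 1.13-era source, MultiheadAttention.forward builds a why_not_fast_path reason string, and one of its gates rejects tensors that are neither CUDA nor CPU. the edit amounts to letting MPS through that gate. a standalone paraphrase of the logic, not the actual PyTorch code:)

```python
import torch

def fast_path_device_ok(tensor_args, allow_mps: bool = False) -> bool:
    # paraphrase of the device gate guarding torch._native_multi_head_attention:
    # stock PyTorch only admits CUDA/CPU tensors; allow_mps=True mimics the
    # file edit that forces the fast path to be attempted on MPS
    allowed = ("cuda", "cpu") + (("mps",) if allow_mps else ())
    return all(t is None or t.device.type in allowed for t in tensor_args)
```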

but it's roughly 3.6x slower than the einsum baseline (37.1 secs vs 10.3 secs), so best not to bother.

and yes, MultiheadAttention outputs the same image as the original einsum implementation, so the port seems correct.
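(if you want to reproduce that equivalence in isolation, here's one way; an illustrative einsum attention in the lucidrains style, with its weights copied into nn.MultiheadAttention so the two can be compared directly:)

```python
import torch
from torch import einsum, nn
from einops import rearrange

torch.manual_seed(0)
dim, heads, tokens = 64, 8, 16
scale = (dim // heads) ** -0.5

# einsum-style self-attention in the lucidrains spirit (illustrative)
to_q = nn.Linear(dim, dim, bias=False)
to_k = nn.Linear(dim, dim, bias=False)
to_v = nn.Linear(dim, dim, bias=False)
to_out = nn.Linear(dim, dim)

def einsum_attn(x):
    q, k, v = to_q(x), to_k(x), to_v(x)
    q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=heads) for t in (q, k, v))
    sim = einsum("b i d, b j d -> b i j", q, k) * scale
    out = einsum("b i j, b j d -> b i d", sim.softmax(dim=-1), v)
    return to_out(rearrange(out, "(b h) n d -> b n (h d)", h=heads))

# nn.MultiheadAttention with the same weights copied in
mha = nn.MultiheadAttention(dim, heads, batch_first=True)
with torch.no_grad():
    mha.in_proj_weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    mha.in_proj_bias.zero_()  # the einsum version has no q/k/v bias
    mha.out_proj.weight.copy_(to_out.weight)
    mha.out_proj.bias.copy_(to_out.bias)

x = torch.randn(2, tokens, dim)
out, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(einsum_attn(x), out, atol=1e-5))  # True
```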

Birch-san · Nov 04 '22 00:11