stable-diffusion-webui
Add scaled dot product attention
Describe what this pull request is trying to achieve.
Adds support for scaled dot product attention. It performs on par with xformers and doesn't require any third-party library to work. It can also produce deterministic results during inference. Works with ROCm (though VRAM consumption is still huge).
To use it, remove all xformers-related flags and add the `--opt-sdp-attention` flag to your webui-user file. You can also use the `--opt-sdp-no-mem-attention` flag to get deterministic results, at the cost of a small performance degradation and increased VRAM consumption.
Additional notes and description of your changes
SDP requires PyTorch >= 2.0, so I've added a check to block its usage on older/unsupported versions.
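The version gate could look roughly like this (a sketch; the helper name `sdp_available` is mine, and the PR itself checks `torch.__version__` directly):

```python
# Sketch of the version check described above. Accepts dev/local builds like
# "2.1.0.dev20230305+cu118" by looking only at the leading major number.
def sdp_available(torch_version: str) -> bool:
    major = torch_version.split(".", 1)[0]
    return major.isdigit() and int(major) >= 2
```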
Training doesn't work for me without disabling flash attention, and without flash attention it performs slower than xformers, so it's not currently possible to completely replace xformers on PyTorch 2.x.
More info about `scaled_dot_product_attention`.
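For intuition, the math behind `scaled_dot_product_attention` is softmax(QKᵀ/√d)V. A plain-Python sketch of that formula (the function name is mine; this is for illustration only, not the fused kernel this PR enables):

```python
import math

def sdp_attention(q, k, v):
    # softmax(Q @ K^T / sqrt(d)) @ V on plain nested lists
    d = len(q[0])                          # head dimension
    scale = 1.0 / math.sqrt(d)
    out = []
    for qi in q:                           # one output row per query
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)                    # subtract max for a stable softmax
        exps = [math.exp(s - m) for s in scores]
        z = sum(exps)
        weights = [e / z for e in exps]
        out.append([sum(w * vj[c] for w, vj in zip(weights, v))
                    for c in range(len(v[0]))])
    return out
```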
Environment this was tested in
- OS: Windows
- Torch: 1.13.1+cu117, 2.0.0.dev20230226+cu118
- Graphics card: NVIDIA RTX 4090, NVIDIA RTX 3070
- Tested on different samplers, LoRAs, hypernetworks, models (based on sd1.x and sd2.x) and ControlNets.
- Other test environments mentioned in this PR's conversation
I'm seeing on-par and even slightly better performance with SDP attention vs xformers:

Quick test using torch 2.1.0.dev20230305, CUDA 11.8 and cuDNN 8.8: no issues with SDP so far, and performance is on-par with xformers (within margin of error) across different batch sizes. GPU memory usage is also comparable (same for the active set and slightly higher for the peak set).
Nice work! According to https://huggingface.co/docs/diffusers/optimization/torch2.0, the performance should be on-par with xformers because they are basically the same thing. But I think using SDPA has two main benefits: 1) no need to install xformers; instead you need to upgrade PyTorch to 2.0+ (which may be easier on Linux); 2) it is compatible with the new torch.compile() in PyTorch 2.0+, which will make SD a little faster.
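The torch.compile() point could be sketched like this, assuming a PyTorch 2.x install (the `"eager"` backend is used here only so the example runs anywhere; real speedups come from the default inductor backend with a working compiler toolchain):

```python
import torch
import torch.nn.functional as F

def attn(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

# Wrap with torch.compile(); SDPA composes with it, unlike external kernels.
compiled_attn = torch.compile(attn, backend="eager")

q = k = v = torch.randn(1, 2, 4, 8)    # (batch, heads, tokens, head_dim)
out = compiled_attn(q, k, v)
```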
torch: 2.1.0.dev20230304+rocm5.4.2, XFX Radeon RX 6900 XT MERC 319 Black LTD, 16368 MiB VRAM
euler a, 20 samples:

| Model, resolution | `--opt-sub-quad-attention` | `--opt-sdp-attention` | `--opt-sdp-attention --medvram` |
|---|---|---|---|
| 1.5, 512 | 4.00 sec, 2976 MiB | 3.20 sec, 3262 MiB | 4.30 sec, 2394 MiB |
| 2.1, 512 | 3.55 sec, 3262 MiB | 3.00 sec, 3388 MiB | 3.90 sec, 2912 MiB |
| 2.0, 768 | 10.35 sec, 3972 MiB | 6.30 sec, 6166 MiB | 7.35 sec, 5074 MiB |
| 2.0, 1024 | 28.00 sec, 4430 MiB | 17 sec (often runs out of memory; one run was 13626 MiB) | 19.00 sec, 10376 MiB |
So it's a lot faster! But it uses a lot more VRAM as resolution increases, compared to `--opt-sub-quad-attention`.
If it can solve the reproducibility problem, it is better.
@Sakura-Luna It can't; images with the same seeds are still non-deterministic. It's still better (at least for inference tasks) as it doesn't require the xformers library for users on PyTorch 2.x.
I checked the backend of SDP; it is the same as xformers. It cannot be reproduced because it uses the memory-efficient kernel. You'd need to add a parameter like `--opt-sdp-no-mem-attention` and then call `torch.backends.cuda.enable_mem_efficient_sdp(False)`.
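A minimal sketch of that suggestion, assuming PyTorch 2.x: disabling the memory-efficient kernel makes SDP fall back to a deterministic path (flash or plain math), at the cost of extra VRAM.

```python
import torch
import torch.nn.functional as F

# Disable the non-deterministic memory-efficient SDP kernel globally.
torch.backends.cuda.enable_mem_efficient_sdp(False)

q = k = v = torch.randn(1, 4, 16, 8)   # (batch, heads, tokens, head_dim)
out = F.scaled_dot_product_attention(q, k, v)
```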
It gives completely fixed results with no noticeable performance loss; the downside is that it uses more graphics card memory. Essentially, PyTorch 2.x users will not need xformers at all.
Maybe you should separate `sdp` and `sdp-no-mem`; it helps to simplify usage parameters.
> Essentially, PyTorch 2.x users will not need xformers at all.
I'm getting a strange runtime exception when training, both with and without memory-efficient SDP; I've mentioned it here. I believe it's related to this issue, since I'm on an sm89 card. So at least for now, xformers is still useful.
> Maybe you should separate `sdp` and `sdp-no-mem`; it helps to simplify usage parameters.
I've made the arguments similar to the existing `--xformers` and `--xformers-flash-attention`, but perhaps you're right.
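The separation being discussed could look roughly like this (a hypothetical sketch mirroring how `--xformers` and `--xformers-flash-attention` already coexist; flag names follow the PR):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--opt-sdp-attention", action="store_true",
                    help="enable torch 2.x scaled dot product attention")
parser.add_argument("--opt-sdp-no-mem-attention", action="store_true",
                    help="SDP with the memory-efficient kernel disabled "
                         "(deterministic, more VRAM)")

# Users pick exactly one of the two flags.
args = parser.parse_args(["--opt-sdp-no-mem-attention"])
```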
I'm getting strange runtime exception when training with and without mem efficient sdp, I've mentioned it here. I believe it's related to this issue, since i'm on sm89 card. So at least for now, xformers is still useful,
It just looks like some devices are being ignored; have you tried disabling flash attention to fix it?
> It just looks like some devices are being ignored; have you tried disabling flash attention to fix it?
Yep, that works fine, but it's noticeably slower than xformers.
If someone can extend flash attention support, we'll have a chance at better results.
> So it's a lot faster! But it uses a lot more VRAM as resolution increases, compared to `--opt-sub-quad-attention`.
Why are you using `--opt-sub-quad-attention`? The default `--opt-split-attention` works on the 6900 XT.
> So it's a lot faster! But it uses a lot more VRAM as resolution increases, compared to `--opt-sub-quad-attention`.
>
> Why are you using `--opt-sub-quad-attention`? The default `--opt-split-attention` works on the 6900 XT.
I assumed it was the default, but I think I was wrong. Sub-quad seems to be the slowest of them all, but uses the least amount of VRAM.
Testing further, it seems like `--opt-sub-quad-attention` is only slightly faster than the others. :(
When I was testing in the post, `ldm.modules.diffusionmodules.model.AttnBlock.forward` was also not overridden. I tested with and without it a few days ago, and I couldn't observe any difference in speed, nor could I see any difference in the image.
I still find it odd that SDP attention is advertised as memory efficient, yet it actually increases VRAM usage on AMD. I get the same behavior with diffusers.