
Add scaled dot product attention

Open · pamparamm opened this pull request 1 year ago • 13 comments

Describe what this pull request is trying to achieve.

Adds support for scaled dot product attention (SDP). It performs on par with xformers and doesn't require any third-party library to work. It can also produce deterministic results during inference. Works with ROCm (though VRAM consumption is still huge).

To use it, remove all xformers-related flags and add the --opt-sdp-attention flag to your webui-user file. You can also use the --opt-sdp-no-mem-attention flag to get deterministic results, at the cost of a small performance degradation and increased VRAM consumption.
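Roughly speaking, the flag switches the attention forward pass over to PyTorch's built-in kernel. A minimal, illustrative sketch of the idea (not the PR's actual hijack code; the tensor shapes are made up):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, tokens, head_dim)
q = torch.randn(2, 8, 4096, 40, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 4096, 40, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 4096, 40, device="cuda", dtype=torch.float16)

# Built into PyTorch >= 2.0; automatically dispatches to a flash,
# memory-efficient, or math kernel, much like xformers does.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
```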

Additional notes and description of your changes

SDP requires PyTorch version >= 2.x, so I've added a check to block its usage on older/unsupported versions.
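A sketch of the kind of guard this implies (a simple feature check; not necessarily the exact code in the PR):

```python
import torch

# The public scaled_dot_product_attention API only exists in PyTorch >= 2.0.
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    raise RuntimeError("--opt-sdp-attention requires PyTorch 2.0 or newer")
```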

Training does not work for me without disabling flash attention, and without flash attention it performs slower than xformers, so it's not currently possible to completely replace xformers on PyTorch 2.x.

More info about scaled_dot_product_attention.

Environment this was tested in

  • OS: Windows
  • Torch: 1.13.1+cu117, 2.0.0.dev20230226+cu118
  • Graphics card: NVIDIA RTX 4090, NVIDIA RTX 3070
  • Tested on different samplers, LoRAs, hypernetworks, models (based on sd1.x and sd2.x) and ControlNets.
  • Other test environments mentioned in this PR's conversation

pamparamm avatar Mar 06 '23 20:03 pamparamm

I'm seeing on-par and even slightly better performance with SDP attention vs. xformers:

[screenshot: benchmark comparison]

dsully avatar Mar 06 '23 20:03 dsully

Quick test using torch 2.1.0.dev20230305, CUDA 11.8 and cuDNN 8.8.
No issues with SDP so far, and performance is on par with xformers (within margin of error) across different batch sizes.
GPU memory usage is also comparable (same for the active set and slightly higher for the peak set).

vladmandic avatar Mar 06 '23 21:03 vladmandic

Nice work! According to https://huggingface.co/docs/diffusers/optimization/torch2.0, the performance should be on par with xformers because they are basically the same thing. But I think using SDPA has two main benefits: 1) no need to install xformers; instead you only need to upgrade PyTorch to 2.0+ (which may be easier on Linux); 2) it is compatible with the new torch.compile() in PyTorch 2.0+, which will make SD a little faster.
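As an illustration of the torch.compile() point, a hedged sketch (the nn.Sequential model is just a stand-in for the SD UNet, not actual webui code):

```python
import torch
import torch.nn as nn

# Stand-in model; in practice this would be the Stable Diffusion UNet.
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).cuda()

# PyTorch 2.0+: JIT-compile the module for some extra speed on top of SDPA.
model = torch.compile(model)

out = model(torch.randn(16, 64, device="cuda"))
```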

WuSiYu avatar Mar 07 '23 11:03 WuSiYu

torch: 2.1.0.dev20230304+rocm5.4.2, GPU: XFX Radeon RX 6900 XT MERC 319 Black LTD, 16368 MiB VRAM

Euler a, 20 samples:

| Model @ resolution | --opt-sub-quad-attention | --opt-sdp-attention | --opt-sdp-attention --medvram |
|---|---|---|---|
| 1.5 @ 512 | 4.00 sec, 2976 MiB | 3.20 sec, 3262 MiB | 4.30 sec, 2394 MiB |
| 2.1 @ 512 | 3.55 sec, 3262 MiB | 3.00 sec, 3388 MiB | 3.90 sec, 2912 MiB |
| 2.0 @ 768 | 10.35 sec, 3972 MiB | 6.30 sec, 6166 MiB | 7.35 sec, 5074 MiB |
| 2.0 @ 1024 | 28.00 sec, 4430 MiB | 17 sec (often runs out of memory; one run was 13626 MiB) | 19.00 sec, 10376 MiB |

So it's a lot faster!

But it uses a lot more VRAM when resolution is increased, compared to --opt-sub-quad-attention.

CapsAdmin avatar Mar 08 '23 21:03 CapsAdmin

If it can solve the reproducibility problem, that would be better.

Sakura-Luna avatar Mar 09 '23 10:03 Sakura-Luna

@Sakura-Luna It can't; images with the same seed are still non-deterministic. It's still better (at least for inference tasks), as it doesn't require the xformers library for users on PyTorch 2.x.

pamparamm avatar Mar 09 '23 14:03 pamparamm

I checked the backend of SDP: it is the same as xformers, and it cannot produce reproducible results because it uses the memory-efficient kernel. You need to add a parameter like --opt-sdp-no-mem-attention and then call torch.backends.cuda.enable_mem_efficient_sdp(False).

That gives completely deterministic results with no noticeable performance loss; the downside is that it uses more graphics card memory. Essentially, PyTorch 2.x users will not need xformers at all.
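A minimal sketch of what that toggle amounts to (using the PyTorch 2.0 backend switches; the PR may wire this up differently):

```python
import torch

# Disable the (non-deterministic) memory-efficient kernel and keep the
# flash and math kernels available, so SDP calls give reproducible results.
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
```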

Sakura-Luna avatar Mar 10 '23 06:03 Sakura-Luna

Maybe you should separate sdp and sdp-no-mem; it would help simplify the usage parameters.
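For illustration, two fully independent flags could be declared like this (illustrative argparse code only, not the repo's actual cmd_args definitions):

```python
import argparse

parser = argparse.ArgumentParser()
# Two independent switches instead of one base flag plus a modifier.
parser.add_argument("--opt-sdp-attention", action="store_true",
                    help="use scaled dot product attention (non-deterministic)")
parser.add_argument("--opt-sdp-no-mem-attention", action="store_true",
                    help="scaled dot product attention without the memory-efficient kernel "
                         "(deterministic, more VRAM)")
args = parser.parse_args()
```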

Sakura-Luna avatar Mar 10 '23 07:03 Sakura-Luna

Essentially, PyTorch 2.x users will not need xformers at all.

I'm getting a strange runtime exception when training, both with and without memory-efficient SDP; I've mentioned it here. I believe it's related to this issue, since I'm on an sm89 card. So at least for now, xformers is still useful.

pamparamm avatar Mar 10 '23 07:03 pamparamm

Maybe you should separate sdp and sdp-no-mem; it would help simplify the usage parameters.

I've made the arguments similar to the existing --xformers and --xformers-flash-attention pair, but perhaps you're right.

pamparamm avatar Mar 10 '23 07:03 pamparamm

I'm getting a strange runtime exception when training, both with and without memory-efficient SDP; I've mentioned it here. I believe it's related to this issue, since I'm on an sm89 card. So at least for now, xformers is still useful.

It just looks like some devices are being ignored; have you tried disabling flash attention to fix it?

Sakura-Luna avatar Mar 10 '23 08:03 Sakura-Luna

It just looks like some devices are being ignored; have you tried disabling flash attention to fix it?

Yep, that works fine, but it's noticeably slower than xformers.
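For reference, a sketch of disabling only the flash kernel around a call (using the PyTorch 2.0 sdp_kernel context manager; the PR itself may handle this differently):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 77, 64, device="cuda", dtype=torch.float16)

# Fall back to the memory-efficient / math kernels, skipping flash attention
# (works around the training error seen on the sm89 card above, at some speed cost).
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)
```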

pamparamm avatar Mar 10 '23 08:03 pamparamm

If someone can extend flash attention, we may be able to get better results.

Sakura-Luna avatar Mar 10 '23 08:03 Sakura-Luna

So it's a lot faster!

But it uses a lot more VRAM when resolution is increased, compared to --opt-sub-quad-attention.

Why are you using --opt-sub-quad-attention? The default --opt-split-attention works on the 6900 XT.

feffy380 avatar Mar 28 '23 23:03 feffy380

So it's a lot faster! But it uses a lot more VRAM when resolution is increased, compared to --opt-sub-quad-attention.

Why are you using --opt-sub-quad-attention? The default --opt-split-attention works on the 6900 XT.

I assumed it was the default, but I think I was wrong. Sub-quad seems to be the slowest of them all but uses the least amount of VRAM.

Testing further, it seems like sub-quad attention is only slightly slower than the others. :(

When I was testing in that post, ldm.modules.diffusionmodules.model.AttnBlock.forward was also not overridden. I tested with and without it a few days ago and couldn't observe any difference in speed, nor any difference in the image.

CapsAdmin avatar Mar 29 '23 08:03 CapsAdmin

I still find it odd that SDP attention is advertised as memory-efficient, yet it actually increases VRAM usage on AMD. I get the same behavior with diffusers.

feffy380 avatar Jul 03 '23 10:07 feffy380