
rescuing flash attention

Open Green-Sky opened this issue 1 year ago • 15 comments

I ported over the flash attention code changes from #221. ~This does not yet fix the existing code behind the define, which broke when gg removed the code from ggml assuming no one was using it :)~ ~The old broken code behind the define also only gets invoked by the VAE.~ The new code is used by UNET, DiT and clip/t5. However, I have not come across any clip or t5 that just works without extra effort, but there are sizable memory reductions for ~sd1,~ sd2, sdxl and flux compute buffers, as well as a speed increase on cuda. I switched the VAE over to the new attention code, but flash attention makes no difference there, so it is disabled for the VAE. There is a lot more left on the table if we employ padding, so flash attention can be applied to more ops.

Flash attention for diffusion models is available via a runtime flag --diffusion-fa.

flux1-schnell performance numbers

All tests were done on a debug build, so the real numbers are likely a bit better. 4 steps, cfg-scale of 1.

CPU

512x512:

compute buffer size: 397MB -> 247MB
speed q3_k: 80.00s/it -> 83.38s/it

768x768:

compute buffer size: 1103MB -> 503MB
speed q3_k: 197.11s/it -> 209.16s/it

CUDA

512x512:

compute buffer size: 398MB -> 248MB
speed q3_k: 1.98s/it -> 1.69s/it
speed q4_k: 1.84s/it -> 1.54s/it

768x768:

compute buffer size: 1105MB -> 505MB
speed q3_k: 4.58s/it -> 3.58s/it
speed q4_k: OOM -> 3.29s/it

direct comparison of the cuda images:

[images: no-FA (cuda_nofa_512x512) vs FA (cuda_fa_512x512)]

difference, darker/more colorful is worse:

[images: error and error^2]

SD2 turbo

8 steps, cfg-scale of 1

CPU

512x512:

compute buffer size: 367MB -> 130MB
speed q8_0: 6.40s/it -> 6.65s/it

768x768:

compute buffer size: 1718MB -> 294MB
speed q8_0: 20.59s/it -> 22.41s/it

CUDA

512x512:

compute buffer size: 367MB -> 130MB
speed q8_0: 6.24it/s -> 8.17it/s

768x768:

compute buffer size: 1718MB -> 295MB
speed q8_0: 1.84it/s -> 3.17it/s

direct comparison of the cuda images:

[images: no-FA (cuda_nofa_512x512) vs FA (cuda_fa_512x512)]

difference, darker/more colorful is worse:

[images: error and error^2]

SDXL realvisxl_v50Lightning

6 steps, cfg-scale of 1.8, dpm++2mv2 karras

CPU

512x512:

compute buffer size: 131MB -> 131MB
speed q8_0: 15.04s/it -> 15.21s/it

768x768:

compute buffer size: 330MB -> 280MB
speed q8_0: 37.36s/it -> 39.30s/it

CUDA

512x512:

compute buffer size: 132MB -> 132MB
speed q8_0: 2.01it/s -> 2.36it/s

768x768:

compute buffer size: 331MB -> 280MB
speed q8_0: 1.23s/it -> 1.15s/it

direct comparison of the cuda images:

[images: no-FA (cuda_nofa_512x512) vs FA (cuda_fa_512x512)]

difference, darker/more colorful is worse:

[images: error and error^2]

There still is the opportunity to pad some tensors to make them fit.

TODO

  • [x] remove the define
  • [x] add with a runtime switch
    • [x] diffusion model
    • [ ] ~vae~
  • [x] more exhaustive testing with supported models
  • [x] add more than just flux numbers to the OP
  • [x] add image comparisons
  • [x] update docs

Please test this code.

props to @FSSRepo for having the code lying around

fixes: https://github.com/leejet/stable-diffusion.cpp/issues/297

images: sd2_turbo.zip flux1-schnell-q4_k.zip flux1-schnell-q3_k.zip

update: added sd2 and sdxl numbers
update2: see the rocm tests further down in the thread; it works but behaves similarly to cpu
update3: added a flash_attn param and exposed it via the --diffusion-fa runtime flag for supported models

Green-Sky avatar Sep 01 '24 13:09 Green-Sky

@Green-Sky -- would this enable flash attention for Vulkan builds as well?

theaerotoad avatar Sep 03 '24 22:09 theaerotoad

@Green-Sky -- would this enable flash attention for Vulkan builds as well?

No, sadly not. Vulkan does not implement GGML_OP_FLASH_ATTN_EXT. However, it looks like cuda, cuda built as rocm (and musa?) and metal all support it.

Would be cool if someone could try rocm and metal builds.
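
For anyone wiring this up themselves, a rough sketch of how a backend could be probed for flash attention support before enabling the path (this assumes the ggml_backend_supports_op API and uses placeholder shapes; it is not code from this PR):

// Rough sketch with placeholder shapes, not code from this PR: build a
// GGML_OP_FLASH_ATTN_EXT node and ask the backend whether it can run it.
#include "ggml.h"
#include "ggml-backend.h"

static bool backend_has_flash_attn(ggml_backend_t backend, ggml_context* ctx) {
    ggml_tensor* q    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 128, 1024, 8); // [d_head, L_q, n_head]
    ggml_tensor* k    = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 128, 1024, 8);
    ggml_tensor* v    = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 128, 1024, 8);
    ggml_tensor* node = ggml_flash_attn_ext(ctx, q, k, v, nullptr, 1.0f, 0.0f, 0.0f);
    return ggml_backend_supports_op(backend, node);
}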

Green-Sky avatar Sep 04 '24 07:09 Green-Sky

I just tested SD3 (2B), and it appears the kv dimensions are not multiples of 256, so flash attention won't work on the MMDiT without padding.
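
For illustration, a tiny sketch of the padding arithmetic involved (the L_k value below is made up, not SD3's actual size):

// Made-up numbers, not SD3's real dimensions: how far a kv length would have
// to be padded to satisfy the L_k % 256 == 0 constraint.
#include <cstdio>

int main() {
    const int L_k    = 4410;                     // hypothetical kv length
    const int kv_pad = (256 - L_k % 256) % 256;  // 0 if already aligned
    std::printf("L_k=%d -> pad by %d -> %d\n", L_k, kv_pad, L_k + kv_pad);
}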

Green-Sky avatar Sep 06 '24 08:09 Green-Sky

I wanted to test the PR with the ROCM backend but I couldn't get it to build. Did you somehow add back the missing ggml_flash_attn function that is used in ggml_extend.hpp? I am not familiar with ggml so it would be nice if you could explain how you got the CUDA backend working.

MineGame159 avatar Sep 06 '24 13:09 MineGame159

I wanted to test the PR with the ROCM backend but I couldn't get it to build. Did you somehow add back the missing ggml_flash_attn function that is used in ggml_extend.hpp? I am not familiar with ggml so it would be nice if you could explain how you got the CUDA backend working.

Oh no, sorry for the confusion: don't use the old define or cmake option, I did not add that back in. Just build it as-is and flash attention (for diffusion models only) is enabled. The old code that gets enabled by the define is what would be used by the VAE, but I have not touched that part yet.

edit: I might push a change soon, where you can enable flash attention for diffusion models via a command line option.

Green-Sky avatar Sep 06 '24 13:09 Green-Sky

Oh interesting. I already tried a comparison between the current master and your PR applied, without the cmake flag, and the results didn't look like it was enabled. Here are the results ("no Flash Attention" marks that I didn't enable the flag). I used the flux1-schnell model at q8_0.

Master - no Flash Attention

total params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB) sampling completed, taking 149.66s - 18.69s/it

PR - no Flash Attention

total params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB) sampling completed, taking 173.13s - 21.62s/it

MineGame159 avatar Sep 06 '24 14:09 MineGame159

@MineGame159 I see, so speed went down on rocm... Also, run it with -v and look for

[DEBUG] ggml_extend.hpp:739  - attention_ext L_q:2304 L_k:2304 n_head:24 C:3072 d_head:128 N:1
[DEBUG] ggml_extend.hpp:763  - using flash attention

If an attention_ext line with tensor sizes is followed by using flash attention, then that specific attention op was converted to flash attention.

Also look for flux compute buffer size: 456.75 MB(VRAM) (in the spam with -v) to see how the compute buffer is affected. Total params size does not change with this patch. :smiley:

Green-Sky avatar Sep 06 '24 14:09 Green-Sky

Oh sorry 😅, I am pretty new to sd.cpp and AI as a whole. And yes, I can see that it is using flash attention.

[DEBUG] ggml_extend.hpp:738  - attention_ext L_k:1280 n_head:24 C:3072 d_head:128
[DEBUG] ggml_extend.hpp:755  - using flash attention

And the compute buffer size:
Master - flux compute buffer size: 398.50 MB(VRAM)
PR - flux compute buffer size: 248.50 MB(VRAM)

MineGame159 avatar Sep 06 '24 14:09 MineGame159

And the resulting images look close to identical? The numbers look good; the compute buffer size reduction is the same as on cuda. Try larger images :)

Green-Sky avatar Sep 06 '24 14:09 Green-Sky

Tested with a 768x768 image and I can't see any difference. The size of the compute buffer is basically the same as on CUDA for you. Only thing is the decreased speed.

1105.07 MB(VRAM) -> 505.07 MB(VRAM)
41.66s/it -> 57.59s/it

[images: master vs PR]

MineGame159 avatar Sep 06 '24 15:09 MineGame159

I diffed your images.

[images: error and error^2]

The error is even less than what I have with cuda.

Tested with a 768x768 image and I can't see any difference. The size of the compute buffer is basically the same as on CUDA for you. Only thing is the decreased speed.

Yea, good to know, thanks for testing. In summary: use it if you use cuda and/or need the extra vram savings.

Green-Sky avatar Sep 06 '24 19:09 Green-Sky

edit: I might push a change soon, where you can enable flash attention for diffusion models via a command line option.

I did. I am now looking into vae.

update: flash attention does not seem to improve the VAE in any way, and it does not work on cuda, since cuda supports d_head up to 256 but the VAE needs 512. So I think I will not enable VAE flash attention, but still remove the old crusty code.
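
For reference, a minimal sketch of that kind of gate (assumed constants, not the exact check in ggml_extend.hpp):

// Minimal sketch, not the actual code: the VAE attention uses d_head == 512,
// which exceeds the ~256 limit observed for the cuda flash-attention kernels,
// so a gate like this rejects it while letting typical UNet/DiT heads through.
static bool fa_ok_for_d_head(int d_head, int backend_max_d_head /* e.g. 256 on cuda */) {
    return d_head % 64 == 0 && d_head <= backend_max_d_head;
}
// fa_ok_for_d_head(512, 256) -> false (VAE on cuda)
// fa_ok_for_d_head(128, 256) -> true  (typical UNet/DiT head size)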

Green-Sky avatar Sep 07 '24 10:09 Green-Sky

Would be nice if someone could test metal.

Green-Sky avatar Sep 07 '24 11:09 Green-Sky

  • Accelerated memory-efficient CPU inference
    • Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB.

imma leave @leejet to update this example, not sure where those numbers came from.

Green-Sky avatar Sep 08 '24 07:09 Green-Sky

If someone wants to play with kv-padding, to enable more cases where flash attention can be used, here is a patch:

diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 8452a0b..518eb6f 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -710,6 +710,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*

     float scale = (1.0f / sqrt((float)d_head));

+    int kv_pad = 0;
     //if (flash_attn) {
     //    LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
     //}
@@ -717,7 +718,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));

     bool can_use_flash_attn = true;
-    can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
+    if (can_use_flash_attn && L_k % 256 != 0) {
+        if (L_k == 77) {
+            kv_pad = 256 - (L_k % 256);
+        } else {
+            can_use_flash_attn = false;
+        }
+    }
+    //can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
     can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check

     // cuda max d_head seems to be 256, cpu does seem to work with 512
@@ -734,11 +742,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
     ggml_tensor* kqv = nullptr;
     //GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
     if (can_use_flash_attn && flash_attn) {
-        //LOG_DEBUG("using flash attention");
+        LOG_DEBUG("using flash attention");
+        if (kv_pad != 0) {
+            LOG_DEBUG("padding kv by %d", kv_pad);
+            k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+        }
         k = ggml_cast(ctx, k, GGML_TYPE_F16);

         v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3));  // [N, n_head, L_k, d_head]
         v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N);  // [N * n_head, L_k, d_head]
+        if (kv_pad != 0) {
+            v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+        }
         v = ggml_cast(ctx, v, GGML_TYPE_F16);

         kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);

I won't be committing this here since the memory savings are minuscule, or it even uses more memory because of the padding. Plus, the images look way off, so maybe there is a mistake somewhere, or it's just the f16 cast in the wrong situation...

Green-Sky avatar Sep 08 '24 16:09 Green-Sky

Rebased on master. I did not yet look at sd3.5; it is disabled, same as sd3, for now. I still consider this ready to merge as-is; I will probably look at sd3.5 in the coming weeks.

Green-Sky avatar Oct 25 '24 14:10 Green-Sky

Thank you for your contribution

leejet avatar Nov 23 '24 04:11 leejet

~~In the flux model I used, FA doesn't seem to be on even when using --diffusion-fa. The reason is L_k=4208 and L_k % 256 != 0. Any idea how to fix it?~~ Never mind, I figured it out. In case anyone has the same issue, try the following on top of @Green-Sky's patch:

if (can_use_flash_attn && L_k % 256 != 0) {
+        if (L_k == 77 || L_k == 4208) {
+            kv_pad = 256 - (L_k % 256);
+        } else {
+            can_use_flash_attn = false;
+        }
+    }

At 832x1216 resolution, diffusion steps went from 30s (FA off) down to 18s (FA on) on my 4090. Big speed up. The quality may have been degraded a bit, but still ok.

@Green-Sky, I think you should commit the patch. It helps with users who want to gain some speed and save VRAM.
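
For what it's worth, here is an untested sketch (not part of either patch) of how the hardcoded L_k whitelist could be generalized, padding whenever the cost of reaching the next multiple of 256 looks acceptable:

// Untested sketch, not part of the PR or the patches above: replace the
// whitelist of L_k values (77, 4208, 3952, ...) with a padding-cost rule.
static bool try_pad_kv(int64_t L_k, int* kv_pad, double max_overhead = 0.10) {
    const int pad = (int)((256 - L_k % 256) % 256);
    if (pad == 0) { *kv_pad = 0; return true; }  // already aligned
    // small text contexts (e.g. the 77-token CLIP context) are cheap to pad
    // in absolute terms even though the relative overhead is large
    if (L_k <= 256 || (double)pad / (double)L_k <= max_overhead) {
        *kv_pad = pad;
        return true;
    }
    return false;  // padding would grow the kv tensors too much
}

With that rule, 77, 4208 and 3952 all get padded, while badly aligned large contexts fall back to the regular attention path.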

bssrdf avatar Dec 14 '24 02:12 bssrdf

That is nice to see. I will try it and test against other models to catch potential degradation. :)

The quality may have been degraded a bit, but still ok. It helps with users who want to gain some speed and save VRAM.

Yea, this usually allows for larger images or higher quants. Good tradeoff.

Green-Sky avatar Dec 14 '24 12:12 Green-Sky

Found another L_k for SDXL

if (L_k == 77 || L_k == 4208 || L_k == 3952) {

bssrdf avatar Dec 14 '24 21:12 bssrdf

@Green-Sky Wondering why you didn't enable FA for SD1.x. Did you face any particular issue?

rmatif avatar Jul 10 '25 20:07 rmatif

I honestly don't remember, and apparently I did not talk about it here. Also, since then, upstream flash attention has changed, more shapes have become compatible on some backends AND vulkan gained working flash attention. :) Is there still such demand for SD1.x? I can investigate if you want.

Green-Sky avatar Jul 10 '25 20:07 Green-Sky

I'm planning to add FA support to the OpenCL backend in the future. Since the maximum buffer size in OpenCL is quite limited on mobile GPUs (around 800–1024 MB on the latest Adreno), the compute buffer will become a precious resource, especially when adding support for older GPUs, which likely have even less available buffer memory. But that's still a long way off, so no hurry :)

rmatif avatar Jul 10 '25 21:07 rmatif

@rmatif Ok, sd1 has d_heads of 40, 80 and 160, while the ggml flash attention cuda kernels require a multiple of 64. At least that's what my understanding was at the time. But I think there are some cases that can just work. (looking at it rn)

Green-Sky avatar Jul 21 '25 11:07 Green-Sky

In fact, cuda and vulkan fa seem to support:

  • 64
  • 80
  • 96
  • 112
  • 128
  • 192
  • 256

I enabled the 80 case, which does not reduce vram usage, but makes the cuda sampling slightly faster (83.17s -> 79.20s). It is not much, since it's just 5 out of 32 attention operations.

Details
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874  -  -> uses flash attention
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874  -  -> uses flash attention
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:192 L_k:192 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:192 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874  -  -> uses flash attention
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874  -  -> uses flash attention
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874  -  -> uses flash attention
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844  - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
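
For reference, a minimal sketch of a head-size check matching that list (meant to mirror the idea, not copied from ggml_extend.hpp):

// Minimal sketch, not the actual gating code: accept only the head sizes the
// cuda/vulkan flash-attention kernels are listed as supporting above.
#include <algorithm>
#include <array>

static bool fa_supports_d_head(int d_head) {
    static constexpr std::array<int, 7> supported = {64, 80, 96, 112, 128, 192, 256};
    return std::find(supported.begin(), supported.end(), d_head) != supported.end();
}
// For sd1: 40 -> false, 80 -> true, 160 -> false, which matches the
// "5 out of 32 attention ops" observation above.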

Green-Sky avatar Jul 21 '25 11:07 Green-Sky

@rmatif embrace your marginal gains for sd1 :) #736

Green-Sky avatar Jul 21 '25 22:07 Green-Sky

@Green-Sky Thank you very much for the quick turnaround! I'm almost done adding FA support to OpenCL, so your timing is perfect.

rmatif avatar Jul 24 '25 02:07 rmatif