stable-diffusion.cpp
rescuing flash attention
I ported over the flash attention code changes from #221. ~This does not yet fix the existing code behind the define, which broke when gg removed the code from ggml assuming no one was using it :)~ ~The old broken code behind the define also only gets invoked by the VAE.~ The new code gets used by UNET, DiT and clip/t5. However, I have not come across any clip or t5 model where it just works without extra work, but there are sizable memory reductions for ~sd1,~ sd2, sdxl and flux compute buffers, as well as a speed increase on cuda. I switched VAE over to using the new attention code, but flash attention makes no difference there, so it is disabled for VAE. There is a lot more left on the table if we employ padding, so that flash attention can be applied to more ops.
Flash attention for diffusion models is available via a runtime flag --diffusion-fa.
flux1-schnell performance numbers
all tests done on a debug build, so the real numbers are likely a bit better. 4 steps, cfg-scale of 1
| backend | resolution | compute buffer size (nofa -> fa) | speed q3_k (nofa -> fa) | speed q4_k (nofa -> fa) |
|---|---|---|---|---|
| CPU | 512x512 | 397MB -> 247MB | 80.00s/it -> 83.38s/it | |
| CPU | 768x768 | 1103MB -> 503MB | 197.11s/it -> 209.16s/it | |
| CUDA | 512x512 | 398MB -> 248MB | 1.98s/it -> 1.69s/it | 1.84s/it -> 1.54s/it |
| CUDA | 768x768 | 1105MB -> 505MB | 4.58s/it -> 3.58s/it | OOM -> 3.29s/it |
direct comparison of the cuda images (nofa vs. fa) and their difference (error and error^2, darker/more colorful is worse): images not preserved.
SD2 turbo
8 steps, cfg-scale of 1
| backend | resolution | compute buffer size (nofa -> fa) | speed q8_0 (nofa -> fa) |
|---|---|---|---|
| CPU | 512x512 | 367MB -> 130MB | 6.40s/it -> 6.65s/it |
| CPU | 768x768 | 1718MB -> 294MB | 20.59s/it -> 22.41s/it |
| CUDA | 512x512 | 367MB -> 130MB | 6.24it/s -> 8.17it/s |
| CUDA | 768x768 | 1718MB -> 295MB | 1.84it/s -> 3.17it/s |
direct comparison of the cuda images (nofa vs. fa) and their difference (error and error^2, darker/more colorful is worse): images not preserved.
SDXL realvisxl_v50Lightning
6 steps, cfg-scale of 1.8, dpm++2mv2 sampler, karras schedule
| backend | resolution | compute buffer size (nofa -> fa) | speed q8_0 (nofa -> fa) |
|---|---|---|---|
| CPU | 512x512 | 131MB -> 131MB | 15.04s/it -> 15.21s/it |
| CPU | 768x768 | 330MB -> 280MB | 37.36s/it -> 39.30s/it |
| CUDA | 512x512 | 132MB -> 132MB | 2.01it/s -> 2.36it/s |
| CUDA | 768x768 | 331MB -> 280MB | 1.23s/it -> 1.15s/it |
direct comparison of the cuda images (nofa vs. fa) and their difference (error and error^2, darker/more colorful is worse): images not preserved.
There is still an opportunity to pad some tensors so that they fit the flash attention shape requirements.
TODO
- [x] remove the define
- [x] add with a runtime switch
- [x] diffusion model
- [ ] ~vae~
- [x] more exhaustive testing with supported models
- [x] add more than just flux numbers to the op
- [x] add image comparisons
- [x] update docs
Please test this code.
props to @FSSRepo for having the code lying around
fixes: https://github.com/leejet/stable-diffusion.cpp/issues/297
images: sd2_turbo.zip flux1-schnell-q4_k.zip flux1-schnell-q3_k.zip
update: added sd2 and sdxl numbers
update2: see rocm tests down in the thread. it works, but behaves similarly to cpu
update3: added flash_attn param and exposed it via the --diffusion-fa runtime flag for supported models.
@Green-Sky -- would this enable flash attention for Vulkan builds as well?
No, sadly not. Vulkan does not implement GGML_OP_FLASH_ATTN_EXT.
However, it looks like cuda, cuda built as rocm (and musa?) and metal all support it.
Would be cool if someone could try rocm and metal builds.
I just tested SD3 (2B), and it appears the kv dimensions are not multiples of 256, so flash attention won't work on the mmdit without padding.
I wanted to test the PR with the ROCM backend but I couldn't get it to build. Did you somehow add back the missing ggml_flash_attn function that is used in ggml_extend.hpp? I am not familiar with ggml so it would be nice if you could explain how you got the CUDA backend working.
Oh no, sorry for the confusion, don't use the old define or cmake option, I did not add that back in. Just build it as-is and it has flash attention (for diffusion models only) enabled. The old code that gets enabled with the define is what the VAE would use, but I have not touched that part yet.
edit: I might push a change soon, where you can enable flash attention for diffusion models via a command line option.
Oh interesting, because I had already compared the current master against your PR without the cmake flag, and the results didn't look like flash attention was enabled. Here are the results ("no Flash Attention" means I did not enable the cmake flag). I used the flux1-schnell model at q8_0.
Master - no Flash Attention
total params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB)
sampling completed, taking 149.66s - 18.69s/it
PR - no Flash Attention
total params memory size = 21471.11MB (VRAM 12152.28MB, RAM 9318.83MB)
sampling completed, taking 173.13s - 21.62s/it
@MineGame159 I see, so speed went down on rocm...
Also, run it with -v and look for
[DEBUG] ggml_extend.hpp:739 - attention_ext L_q:2304 L_k:2304 n_head:24 C:3072 d_head:128 N:1
[DEBUG] ggml_extend.hpp:763 - using flash attention
If an attention_ext line with tensor sizes is followed by "using flash attention", then that specific attention was converted to a flash attention.
Also look for flux compute buffer size: 456.75 MB(VRAM) (in the spam with -v) to see how the compute buffer is affected. Total params size does not change with this patch. :smiley:
Oh sorry 😅, I am pretty new to sd.cpp and AI as a whole. And yes, I can see that it is using flash attention.
[DEBUG] ggml_extend.hpp:738 - attention_ext L_k:1280 n_head:24 C:3072 d_head:128
[DEBUG] ggml_extend.hpp:755 - using flash attention
And the compute buffer size:
Master: flux compute buffer size: 398.50 MB(VRAM)
PR: flux compute buffer size: 248.50 MB(VRAM)
And do the resulting images look close to identical?
The numbers look good, the compute buffer size reduction is the same as on cuda. Try larger images :)
Tested with a 768x768 image and I can't see any difference. The size of the compute buffer is basically the same as on CUDA for you. Only thing is the decreased speed.
compute buffer size: 1105.07 MB(VRAM) -> 505.07 MB(VRAM), speed: 41.66s/it -> 57.59s/it
I diffed your images (panels: error | error^2; images not preserved).
The error is even lower than what I get with cuda.
Yea, good to know, thanks for testing. In summary: use it if you use cuda and/or need the extra vram savings.
I did push the change that enables flash attention for diffusion models via a command line option. I am now looking into vae.
update: flash attention does not seem to improve VAE in any way, and it does not work on cuda, since cuda supports d_head up to 256, but VAE needs 512. So I will not enable VAE flash attention, but I will still remove the old crusty code.
Would be nice if someone could test metal.
- Accelerated memory-efficient CPU inference
- Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB.
imma leave @leejet to update this example, not sure where those numbers came from.
If someone wants to play with kv-padding, to enable more cases where flash attention can be used, here is a patch:
diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 8452a0b..518eb6f 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -710,6 +710,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
float scale = (1.0f / sqrt((float)d_head));
+ int kv_pad = 0;
//if (flash_attn) {
// LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N);
//}
@@ -717,7 +718,14 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0));
bool can_use_flash_attn = true;
- can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
+ if (can_use_flash_attn && L_k % 256 != 0) {
+ if (L_k == 77) {
+ kv_pad = 256 - (L_k % 256);
+ } else {
+ can_use_flash_attn = false;
+ }
+ }
+ //can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0;
can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check
// cuda max d_head seems to be 256, cpu does seem to work with 512
@@ -734,11 +742,18 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context*
ggml_tensor* kqv = nullptr;
//GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn);
if (can_use_flash_attn && flash_attn) {
- //LOG_DEBUG("using flash attention");
+ LOG_DEBUG("using flash attention");
+ if (kv_pad != 0) {
+ LOG_DEBUG("padding kv by %d", kv_pad);
+ k = ggml_pad(ctx, k, 0, kv_pad, 0, 0);
+ }
k = ggml_cast(ctx, k, GGML_TYPE_F16);
v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head]
v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head]
+ if (kv_pad != 0) {
+ v = ggml_pad(ctx, v, 0, kv_pad, 0, 0);
+ }
v = ggml_cast(ctx, v, GGML_TYPE_F16);
kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0);
I won't be committing this here, since the memory savings are minuscule, or it even uses more memory because of the padding.
Plus, the images look way off, so maybe there is a mistake somewhere or it's just the f16 cast in the wrong situation...
rebased on master. I did not look at sd3.5 yet; it is disabled for now, same as sd3. I still consider this ready to merge as-is, and I will probably look at sd3.5 in the coming weeks.
Thank you for your contribution
~~In the flux model I used, FA doesn't seem to be on even when using --diffusion-fa. The reason is L_k=4208 and L_k % 256 != 0. Any idea how to fix it?~~ Never mind, I figured it out. In case anyone has the same issue, try the following on top of @Green-Sky 's patch:
if (can_use_flash_attn && L_k % 256 != 0) {
+ if (L_k == 77 || L_k == 4208) {
+ kv_pad = 256 - (L_k % 256);
+ } else {
+ can_use_flash_attn = false;
+ }
+ }
At 832x1216 resolution, diffusion steps went from 30s (FA off) down to 18s (FA on) on my 4090. Big speed up. The quality may have been degraded a bit, but still ok.
@Green-Sky, I think you should commit the patch. It helps users who want to gain some speed and save VRAM.
That is nice to see. I will try it and test against other models to catch potential degradation. :)
Yea, this usually allows for larger images or higher quants. Good tradeoff.
Found another L_k for SDXL
if (L_k == 77 || L_k == 4208 || L_k == 3952) {
@Green-Sky Wondering why you didn't enable FA for SD1.x. Did you face any particular issue?
I honestly don't remember, and apparently I did not talk about it here. Also, since then, upstream flash attention has changed: more shapes have become compatible on some backends AND vulkan gained working flash attention. :) Is there still much demand for SD1.x? I can investigate if you want.
I'm planning to add FA support to the OpenCL backend in the future. Since the maximum buffer size in OpenCL is quite limited on mobile GPUs (around 800–1024 MB on the latest Adreno), the compute buffer will become a precious resource, especially when adding support for older GPUs, which likely have even less available buffer memory. But that's still a long way off, so no hurry :)
@rmatif Ok, sd1 has d_head values of 40, 80 and 160, while the ggml flash attention cuda kernels require a multiple of 64. At least that was my understanding at the time. But I think there are some cases that can just work. (looking at it rn)
In fact, cuda and vulkan fa seem to support:
- 64
- 80
- 96
- 112
- 128
- 192
- 256
I enabled the 80 case, which does not reduce vram usage, but makes cuda sampling slightly faster (83.17s -> 79.20s). It is not much, since it's just 5 out of 32 attention operations.
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874 - -> uses flash attention
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874 - -> uses flash attention
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:192 L_k:192 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:192 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:768 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:768 L_k:77 n_head:8 C:1280 d_head:160 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874 - -> uses flash attention
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874 - -> uses flash attention
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:3072 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:874 - -> uses flash attention
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:3072 L_k:77 n_head:8 C:640 d_head:80 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:12288 n_head:8 C:320 d_head:40 N:1
[DEBUG] ggml_extend.hpp:844 - attention_ext L_q:12288 L_k:77 n_head:8 C:320 d_head:40 N:1
@rmatif embrace your marginal gains for sd1 :) #736
@Green-Sky Thank you very much for your reactivity! I'm almost done adding FA support to OpenCL, so your timing is perfect