[Feat]: Attention backend selection for Diffusers
Describe your use-case.
The latest version of Diffusers supports configuring or selecting a specific attention backend, such as FlashAttention-2 or FlashAttention-3 (which supports the backward pass).
OneTrainer could potentially benefit in performance if this were togglable via the provided API, as per the documentation below, using a context manager: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends
with attention_backend("flash"):
    ...  # training/inference goes here
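For a fuller picture, here is a hypothetical end-to-end sketch. FluxPipeline, the prompt, and the import path of attention_backend are my assumptions; recent diffusers versions provide the context manager from their attention dispatch module, but the exact path may differ between versions:

import torch
from diffusers import FluxPipeline
# Assumption: import path may differ between diffusers versions
from diffusers.models.attention_dispatch import attention_backend

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Only attention calls dispatched inside the context manager use the selected backend.
with attention_backend("flash"):
    image = pipe("a photo of a cat", num_inference_steps=20).images[0]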
It is also possible to force OneTrainer to use a particular attention backend using the following environment variable (example using FA2):
DIFFUSERS_ATTN_BACKEND=flash
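Equivalently, a hypothetical launcher-script sketch (assuming the variable has to be set before diffusers is imported so the dispatcher picks it up):

import os

# Same effect as exporting DIFFUSERS_ATTN_BACKEND=flash in the shell before starting OneTrainer.
os.environ["DIFFUSERS_ATTN_BACKEND"] = "flash"

import diffusers  # noqa: E402  # imported after the variable is set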
The feature itself is marked as experimental, and I have yet to find out the potential drawbacks of using this for training. So far I have observed that for training a Chroma LoRA with batch size 3 on an RTX 5090, the time per iteration reduced from 9.5 s/it to 6.4 s/it on my machine.
There are checks associated with each of the backends (i.e. shape, dtypes, etc.) which are skipped by default. Depending on the selected training types in OneTrainer, things can easily break if they are not aligned. It's possible to enable these checks in Diffusers, but that incurs extra overhead per attention call (it can still be useful as a quick sanity check to validate that the current configuration is sane):
DIFFUSERS_ATTN_CHECKS=1
I tried enabling this flag and it failed one of the assertions:
Attention mask must match the key's second to last dimension.
Probably needs some more investigation.
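For context, here is a small sketch of the invariant the failing check seems to enforce. This is my reading of the error message with placeholder shapes in the usual (batch, heads, seq_len, head_dim) layout, not the actual diffusers code:

import torch

# The mask's last dimension must equal the key's second-to-last (sequence) dimension.
q = torch.randn(2, 8, 256, 64)                     # (batch, heads, q_len, head_dim)
k = torch.randn(2, 8, 300, 64)                     # (batch, heads, kv_len, head_dim)
attn_mask = torch.ones(2, 1, 256, 300, dtype=torch.bool)

assert attn_mask.shape[-1] == k.shape[-2], \
    "Attention mask must match the key's second to last dimension."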
What would you like to see as a solution?
Look into potentially adding a dropdown to select supported attention backends (or at least documenting the feature, along with the necessary caveats and tested configurations/models).
This feature is currently marked as experimental by HuggingFace.
Have you considered alternatives? List them here.
No response
Preliminary investigation for Chroma suggests the shape of the attention_mask does not match what is expected, which leads to the validation error: https://github.com/huggingface/diffusers/blob/8f80dda193f79af3ccd0f985906d61123d69df08/src/diffusers/models/transformers/transformer_chroma.py#L256
The issue can also be replicated with the standard SDPA backend and is not specific to FA:
DIFFUSERS_ATTN_BACKEND=native
DIFFUSERS_ATTN_CHECKS=1
The issue has been raised in diffusers for further investigation there: https://github.com/huggingface/diffusers/issues/12575
After doing some research: standard FlashAttention does not support attention masking out of the box, which some models require to implement functionality such as caption dropout. This will be a problem if we intend to use it as a backend. https://github.com/Dao-AILab/flash-attention/issues/352
PyTorch has another, newer backend called FlexAttention, which supports custom attention masking with the performance of FlashAttention-2, making it a much more attractive option. The diffusers attention backends already support FlexAttention out of the box, which is great - but I have yet to get it working on my machine.
EDIT - After some tinkering, I had to torch.compile flex_attention first (with max-autotune-no-cudagraphs to resolve an OOM issue with triton) and also upgrade PyTorch to 2.9 and triton to the latest version due to a bug with PyTorch 2.7.x (I also confirmed that PyTorch 2.8 works fine with a compatible version of triton).
DIFFUSERS_ATTN_BACKEND=flex
import torch
from torch.nn.attention import flex_attention
# Compile flex_attention in place, since diffusers does not do this itself
flex_attention.flex_attention = torch.compile(flex_attention.flex_attention, mode="max-autotune-no-cudagraphs", dynamic=True, fullgraph=True)
# Compile create_block_mask for minor speedup + VRAM reduction for attention masks
flex_attention.create_block_mask = torch.compile(flex_attention.create_block_mask, mode="max-autotune-no-cudagraphs", dynamic=True, fullgraph=True)
I was facing several issues with cudagraphs enabled, hence they are disabled in the monkeypatch above.
With the above changes, the time per step in my repro reduces from 9.5 s/it → 6.7 s/it.
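For illustration, here is a minimal sketch of how a padding-style mask (the kind of thing caption dropout needs) can be expressed with FlexAttention via create_block_mask. `valid_tokens` is a hypothetical per-sample validity mask and the shapes are placeholders, not anything taken from OneTrainer or diffusers:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

batch, heads, seq_len, head_dim = 2, 8, 1024, 64
q = torch.randn(batch, heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
# Boolean (batch, seq_len) tensor marking non-padded token positions.
valid_tokens = torch.ones(batch, seq_len, dtype=torch.bool, device="cuda")

def padding_mask_mod(b, h, q_idx, kv_idx):
    # Allow attention only to non-padded key positions of this sample.
    return valid_tokens[b, kv_idx]

block_mask = create_block_mask(padding_mask_mod, B=batch, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)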
How do I install flash attn?
pip install flash_attn
fails with
Installing build dependencies ... done
Getting requirements to build wheel ... error
error: subprocess-exited-with-error
× Getting requirements to build wheel did not run successfully.
│ exit code: 1
╰─> [20 lines of output]
Traceback (most recent call last):
File ".../OneTrainer/venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 389, in <module>
main()
File ".../OneTrainer/venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 373, in main
json_out["return_val"] = hook(**hook_input["kwargs"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../OneTrainer/venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 143, in get_requires_for_build_wheel
return hook(config_settings)
^^^^^^^^^^^^^^^^^^^^^
File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 331, in get_requires_for_build_wheel
return self._get_build_requires(config_settings, requirements=[])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 301, in _get_build_requires
self.run_setup()
File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 512, in run_setup
super().run_setup(setup_script=setup_script)
File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 317, in run_setup
exec(code, locals())
File "<string>", line 22, in <module>
ModuleNotFoundError: No module named 'torch'
[end of output]
note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed to build 'flash_attn' when getting requirements to build wheel
even though the OneTrainer venv is activated.
Installing as recommended on their page
python setup.py install
runs out of memory, with 64 GB RAM and 32 GB swap.
Try with build isolation disabled and make sure to restrict the MAX_JOBS used by ninja (ensure the ninja package is installed beforehand for a faster build), or you will OOM your system:
MAX_JOBS=4 pip install . --no-build-isolation
Now it complains that my nvcc is too old. I could probably solve this somehow - and thank you - but in order to use it for OneTrainer, it needs to be simple to install.
Is there a way to install it using pip with no compilation or special options?
@dxqb There are a few repos which offer precompiled wheels, but only certain PyTorch+CUDA+Python configurations are supported.
Below is an example for Linux: https://github.com/mjun0812/flash-attention-prebuild-wheels
I think we should focus on getting FlexAttention into OneTrainer though (at least that is baked into PyTorch itself); FlashAttention might be going down a rabbit hole, which is not a good option to start with.
Also, I fully trained a Chroma LoRA last night using masked training + prior preservation with FlexAttention (as per https://github.com/Nerogar/OneTrainer/issues/1090#issuecomment-3477943415) and the results turned out fine - the model converged as expected and worked fine in ComfyUI.
Benchmarks show that torch SDP is already as fast as flash attention:
flash: 100%|███████████████████████████████████████████████████████████████████████████| 5000/5000 [00:25<00:00, 196.18it/s]
torch SDP: 100%|███████████████████████████████████████████████████████████████████████| 5000/5000 [00:26<00:00, 190.36it/s]
torch SDP no flash: 100%|██████████████████████████████████████████████████████████████| 5000/5000 [00:40<00:00, 122.42it/s]
Only if you specifically instruct torch not to use the flash algorithm is it slower (last benchmark line):
torch.backends.cuda.enable_flash_sdp(False)
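For reference, here is a hypothetical reconstruction of a benchmark loop along these lines; run_benchmark and the tensor shapes are my own placeholders, not the script that produced the numbers above:

import torch
from tqdm import tqdm

def run_benchmark(fn, label, iters=5000):
    for _ in tqdm(range(iters), desc=label):
        fn()
        torch.cuda.synchronize()  # count completed GPU work, not just kernel launches

# Dummy tensors in (batch, heads, seq_len, head_dim) layout.
q = k = v = torch.randn(3, 24, 4096, 128, device="cuda", dtype=torch.bfloat16)

run_benchmark(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v), "torch SDP")

torch.backends.cuda.enable_flash_sdp(False)  # disallow the flash algorithm for SDPA
run_benchmark(lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v), "torch SDP no flash")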
I have observed that for training a Chroma LoRA with batch size 3 on an RTX 5090, the seconds per iteration reduced from 9.5s/it to 6.4s/it on my machine.
Please reproduce this step by step. If the above is correct, this cannot have been flash attention on its own.
As discussed on Discord, the issue appears specific to Windows, where the Torch SDP backend cannot use the native FlashAttention-2 based kernel because the Windows build is not compiled with FlashAttention support in the first place.
Below are benchmark results using different attention implementations that have been torch.compile'd on a dummy input:
Attempting to force the flash attention backend with PyTorch leads to an error on Windows:
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None)
    run_benchmark(fn, "torch SDP flash")
C:\OneTrainer\venv\Lib\site-packages\torch\_dynamo\utils.py:3546: UserWarning: Flash attention kernel not used because: (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:917.)
return node.target(*args, **kwargs) # type: ignore[operator]
C:\OneTrainer\venv\Lib\site-packages\torch\_dynamo\utils.py:3546: UserWarning: Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:726.)
There is an open issue at the pytorch repo regarding support: https://github.com/pytorch/pytorch/issues/108175
This likely explains why switching the attention backend to flex yields a performance boost on my Windows machine.
Did some further testing, this time with a fresh dataset (with optional masks), also testing on a natively booted Linux distro with identical settings in OneTrainer, and came back with the following results. The results below specifically target the Chroma model.
At least with the default attention backend (Torch SDPA), performance is identical across matching PyTorch versions for both Linux and Windows for this specific workload.
The biggest gains on both Windows and Linux are seen in PyTorch 2.9 with FlexAttention (which replicates my earlier findings).
Note that Chroma uses attention-masks during training, so it cannot use the FlashAttention kernel for Torch SDP.
I will run some more tests on a different attention-heavy model that doesn't use attention masks, such as Flux.
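As a quick illustration of that limitation, a sketch with placeholder shapes (the exact error or warning text depends on the PyTorch build):

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# Restricting SDPA to the flash kernel while passing an explicit attention mask
# should fail, since that backend does not support an arbitrary attn_mask.
q = k = v = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
attn_mask = torch.ones(1, 1, 1024, 1024, dtype=torch.bool, device="cuda")

try:
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
except RuntimeError as err:
    print("flash kernel unavailable with an attention mask:", err)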
I was now able to reproduce this:
RTX 5090, 1024 px, bs 3
PyTorch 2.9.0: 7.0 s/it native vs. 5.0 s/it with flex attention
It does not strictly depend on 2.9.0; it also works with 2.8.0 (but not with 2.7.1):
PyTorch 2.8.0: 7.1 s/it native vs. 5.2 s/it with flex attention
2.8.0 is generally a bit slower, but close enough. All my further tests will be with 2.8.0 because I experienced some bugs with 2.9.0 and diffusers.
The large speed-up of about 25% seems to depend on quite special conditions though.
- on my 4070 the speed-up is only about 13% instead of 25%:
RTX 4070, 1024 px
PyTorch 2.8.0: 16.1 s/it native vs. 14.0 s/it with flex attention
- it depends on a large sequence length
With 512 px bs 3, there is maybe 10% speed-up on a 5090 (difficult to say), and virtually none on my 4070:
RTX 5090, 512 px
PyTorch 2.8.0: 1.2 s/it native vs. 1.1 s/it with flex attention
RTX 4070, 512 px
PyTorch 2.8.0: 3.1 s/it native vs. 3.0 s/it with flex attention
This combination is why I wasn't able to reproduce it:
- no speed-up on 2.7.1
- no speed-up on smaller sequence lengths on my 4070
I would therefore conclude that flex attention can be more efficient probably because it is compiled and therefore specialized at runtime for whatever sequence length is used, while torch SDP cannot do that - but overall the algorithm is likely the same.
Still not bad for those conditions where it helps and we could use it once it is stable.
However, currently and somewhat surprisingly, it cannot be combined with using torch.compile for the rest of the model. This PR https://github.com/Nerogar/OneTrainer/pull/1034 fails with flex attention. I'll open a torch issue about this.
Therefore, currently we can only choose one or the other - but using torch.compile for everything without flex is much faster than using it only for flex:
RTX 4070, 1024 px
PyTorch 2.8.0: 16.1 s/it native vs. 14.0 s/it flex vs. 12.5 s/it torch.compile
RTX 4070, 512 px, bs 3
PyTorch 2.8.0: 3.1 s/it native vs. 3.0 s/it flex vs. 2.2 s/it torch.compile
The third number is with torch.compile only, not the other features of the PR.
https://github.com/pytorch/pytorch/issues/167116
I ran a set of tests using FLUX.1-dev on the same dataset. I did post some numbers previously, but realised I had made a huge mistake: the FlexAttention backend wasn't actually active. Below are the results:
As far as standard Torch SDP goes, Linux outperforms Windows in every scenario, likely because the FlashAttention kernel is precompiled into the Linux binaries, whereas even with the latest PyTorch that is not the case on Windows.
However, with PyTorch 2.9, using FlexAttention on Windows pretty much closes that performance gap.
Using FlexAttention is currently not an option, and probably won't be in the near future. Torch calls it a "prototype feature", and there are probably more open issues with it than the one we've found above.
So if the torch SDP algorithm is much worse on Windows, the only alternative would be to use another, external flash attention implementation - for example flash_attn (with an easy install method), which is how we started above.
Yes, with caveats obviously - it is only suitable for certain models.
from Discord:
I actually get significant speedup from [flash varlen attention], but I had to manipulate the tensors a bit to get it to pick the optimized kernel. I also made sure I was precomputing the masks per batch rather than naively doing it per layer. This is a bit messy, but should give you the general concept
https://gist.github.com/deepdelirious/ed1dcbc131ce6cad7a5a196c6c09aa24
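For the general idea only (this is not taken from the gist): flash_attn's varlen entry point avoids padding and explicit attention masks by packing the sequences of a batch into one tensor and passing cumulative sequence lengths. The shapes below are placeholders:

import torch
from flash_attn import flash_attn_varlen_func

# Two samples of different length packed into a single (total_tokens, heads, head_dim) tensor.
seq_lens = torch.tensor([512, 300], dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))  # [0, 512, 812]
total_tokens, heads, head_dim = int(seq_lens.sum()), 24, 128

q = torch.randn(total_tokens, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=int(seq_lens.max()), max_seqlen_k=int(seq_lens.max()),
)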