[Feat]: Attention backend selection for Diffusers

Open · zzlol63 opened this issue 2 months ago · 15 comments

Describe your use-case.

The latest version of Diffusers supports configuring or selecting a specific attention backend, such as FlashAttention-2 or FlashAttention-3 (which support the backward pass).

OneTrainer could potentially gain performance if this were toggleable via the provided context-manager API, as described in the documentation: https://huggingface.co/docs/diffusers/main/en/optimization/attention_backends

from diffusers.models.attention_dispatch import attention_backend

with attention_backend("flash"):
    ...  # training/inference goes here

It is also possible to force OneTrainer to use a particular attention backend using the following environment variable (example using FA2): DIFFUSERS_ATTN_BACKEND=flash

The feature itself is marked as experimental, and I have yet to find out the potential drawbacks of using it for training. So far, when training a Chroma LoRA with batch size 3 on an RTX 5090, I have observed the time per iteration drop from 9.5 s/it to 6.4 s/it on my machine.

Each backend has associated checks (e.g. shapes, dtypes) which are skipped by default. Depending on the training types selected in OneTrainer, things can easily break if these are not aligned. It's possible to enable the checks in Diffusers; this adds extra overhead to every attention call, but it can be a useful quick sanity check that the current configuration is valid: DIFFUSERS_ATTN_CHECKS=1
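A minimal sketch of how a launcher could set the two environment variables before Diffusers is imported. The helper name `configure_diffusers_attention` is hypothetical, the allowed names are limited to the three backends discussed in this thread, and the premise that Diffusers reads these variables once at import time is an assumption.

```python
import os

# Only the backend names used in this thread; Diffusers supports more (see its docs).
KNOWN_BACKENDS = {"native", "flash", "flex"}


def configure_diffusers_attention(backend: str = "native", checks: bool = False) -> None:
    """Hypothetical helper: select the Diffusers attention backend via env vars.

    Call this before `import diffusers`, assuming the variables are read at import time.
    """
    if backend not in KNOWN_BACKENDS:
        raise ValueError(f"unsupported attention backend: {backend!r}")
    os.environ["DIFFUSERS_ATTN_BACKEND"] = backend
    # Shape/dtype validation costs extra time per attention call; useful as a sanity check.
    os.environ["DIFFUSERS_ATTN_CHECKS"] = "1" if checks else "0"


configure_diffusers_attention("flex", checks=True)
# import diffusers  # only after the variables are set
```

Validating the name up front turns a typo into an immediate error instead of a silent fallback.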

I tried enabling this flag and it failed one of the assertions: Attention mask must match the key's second to last dimension.
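A sketch of the kind of check that error message describes, operating on plain shape tuples. It assumes the torch SDPA layout, where the key is (batch, heads, kv_len, head_dim), so the key's second-to-last dimension is the key sequence length; the real Diffusers check may differ in detail.

```python
def check_attention_mask_shape(key_shape, mask_shape):
    """Raise if the mask's last dim differs from the key's second-to-last dim.

    Assumed key layout: (batch, heads, kv_len, head_dim).
    """
    kv_len = key_shape[-2]
    if mask_shape[-1] != kv_len:
        raise ValueError(
            "Attention mask must match the key's second to last dimension "
            f"(mask last dim {mask_shape[-1]} != key dim {kv_len})"
        )


# A (batch, 1, 1, kv_len) padding mask broadcast over heads and queries passes.
check_attention_mask_shape((2, 4, 24, 32), (2, 1, 1, 24))
```

If a caller passed the key as (batch, seq_len, heads, head_dim) instead, `key_shape[-2]` would be the head count, so a layout disagreement between caller and check is one way such an assertion can fire even with a correctly sized mask.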

Probably needs some more investigation.

What would you like to see as a solution?

Look into adding a dropdown to select supported attention backends (or, alternatively, documenting the option along with the necessary caveats and tested configurations/models).

This feature is marked as experimental currently by HuggingFace.

Have you considered alternatives? List them here.

No response

zzlol63 · Nov 02 '25 10:11

Preliminary investigation for Chroma suggests that the shape of the attention_mask does not match what is expected, which leads to the validation error: https://github.com/huggingface/diffusers/blob/8f80dda193f79af3ccd0f985906d61123d69df08/src/diffusers/models/transformers/transformer_chroma.py#L256

The issue can also be replicated with the standard SDPA backend and is not specific to FA:

DIFFUSERS_ATTN_BACKEND=native
DIFFUSERS_ATTN_CHECKS=1

An issue has been raised in diffusers for further investigation there: https://github.com/huggingface/diffusers/issues/12575

zzlol63 · Nov 02 '25 10:11

After doing some research: standard FlashAttention does not support attention masking out of the box, which some models require to implement functionality such as caption dropout. This will be a problem if we intend to use it as a backend. https://github.com/Dao-AILab/flash-attention/issues/352
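To make "attention masking" concrete: torch SDPA accepts a boolean `attn_mask` (True = may attend), which is what a padded or dropped caption needs. A minimal sketch, assuming a hypothetical joint layout of [text tokens padded to a fixed capacity | image tokens]; the helper name and sizes are illustrative, not OneTrainer's code.

```python
import torch
import torch.nn.functional as F


def key_padding_attn_mask(text_lengths, text_capacity, num_image_tokens):
    """Boolean mask of shape (batch, 1, 1, seq_len); True means the key is visible."""
    lengths = torch.as_tensor(text_lengths)
    kv_idx = torch.arange(text_capacity + num_image_tokens)
    # Valid text tokens and all image tokens stay visible; padded text slots are hidden.
    keep = (kv_idx[None, :] < lengths[:, None]) | (kv_idx[None, :] >= text_capacity)
    # Broadcasts over heads and query positions: (batch, heads, q_len, kv_len).
    return keep[:, None, None, :]


torch.manual_seed(0)
batch, heads, text_cap, n_img, head_dim = 2, 4, 8, 16, 32
seq = text_cap + n_img
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Sample 0 keeps 5 caption tokens; sample 1 simulates a dropped caption (no valid text).
mask = key_padding_attn_mask([5, 0], text_cap, n_img)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Every query still sees the image tokens, so no row is fully masked (which would produce NaNs). A kernel without mask support has no way to express `mask` here.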

PyTorch has a newer backend called FlexAttention, which supports custom attention masks with performance comparable to FlashAttention-2, making it a much more attractive option. The Diffusers attention backends already support FlexAttention out of the box, which is great, but I have yet to get it working on my machine.
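FlexAttention expresses a mask as a `mask_mod(b, h, q_idx, kv_idx) -> bool` function that `create_block_mask` turns into a block-sparse mask. A minimal sketch of a caption-padding mask_mod, under the same hypothetical [padded text | image] layout as the SDPA example; the factory name is illustrative. It uses only comparisons and `|`, so it can be checked with plain ints here, while real use would pass `text_lengths` as a tensor on the right device.

```python
def make_padding_mask_mod(text_lengths, text_capacity):
    """Hypothetical factory for a FlexAttention-style mask_mod.

    Layout assumption: [text tokens padded to text_capacity | image tokens].
    Keys are visible if they are image tokens or valid (unpadded) text tokens.
    """
    def mask_mod(b, h, q_idx, kv_idx):
        is_image_token = kv_idx >= text_capacity
        is_valid_text = kv_idx < text_lengths[b]
        return is_image_token | is_valid_text

    return mask_mod


mask_mod = make_padding_mask_mod([5, 0], text_capacity=8)
# With PyTorch >= 2.5 (torch.nn.attention.flex_attention), roughly:
#   block_mask = create_block_mask(mask_mod, B, None, Q_LEN, KV_LEN, device="cuda")
#   out = flex_attention(q, k, v, block_mask=block_mask)
```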

EDIT: After some tinkering, I had to wrap flex_attention with torch.compile first (using max-autotune-no-cudagraphs to resolve an OOM issue with Triton), and also upgrade PyTorch to 2.9 and Triton to the latest version due to a bug in PyTorch 2.7.x (I also confirmed that PyTorch 2.8 works fine with a compatible version of Triton).

DIFFUSERS_ATTN_BACKEND=flex

import torch
from torch.nn.attention import flex_attention

# Compile flex_attention in place, since diffusers does not do this itself
flex_attention.flex_attention = torch.compile(flex_attention.flex_attention, mode="max-autotune-no-cudagraphs", dynamic=True, fullgraph=True)

# Compile create_block_mask for minor speedup + VRAM reduction for attention masks
flex_attention.create_block_mask = torch.compile(flex_attention.create_block_mask, mode="max-autotune-no-cudagraphs", dynamic=True, fullgraph=True)

I ran into several issues with CUDA graphs enabled, hence they are disabled in the monkeypatch above.
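Since the monkeypatch above only worked here on PyTorch 2.8 and 2.9 (2.7.x hit a bug), a launcher could gate it on the installed version. A minimal sketch; the function name and the `(2, 8)` cut-off are assumptions drawn from the versions reported in this thread, not an official compatibility matrix.

```python
def flex_training_supported(torch_version: str) -> bool:
    """Hypothetical gate: True for PyTorch 2.8 or newer.

    Local build suffixes such as "+cu128" are ignored.
    """
    release = torch_version.split("+")[0]
    major, minor = (int(part) for part in release.split(".")[:2])
    return (major, minor) >= (2, 8)


# Usage idea:
#   import torch
#   if flex_training_supported(torch.__version__):
#       ...apply the torch.compile monkeypatch...
```

Comparing `(major, minor)` tuples rather than strings keeps e.g. "2.10" ordered after "2.8".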

With the above changes, the time per step in my repro drops from 9.5 s/it to 6.7 s/it.
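Step times in s/it and throughput in it/s move in opposite directions, so the same change can be quoted as two different percentages. A small worked example using the 9.5 and 6.7 s/it figures from this repro:

```python
def compare_step_times(before_s_per_it: float, after_s_per_it: float):
    """Return (fraction of time saved per step, throughput multiplier)."""
    time_saved = (before_s_per_it - after_s_per_it) / before_s_per_it
    throughput_gain = before_s_per_it / after_s_per_it
    return time_saved, throughput_gain


saved, gain = compare_step_times(9.5, 6.7)
print(f"{saved:.1%} less time per step, {gain:.2f}x iterations per second")
# prints "29.5% less time per step, 1.42x iterations per second"
```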

zzlol63 · Nov 02 '25 12:11

How do I install flash attn? pip install flash_attn fails with

 Installing build dependencies ... done
  Getting requirements to build wheel ... error
  error: subprocess-exited-with-error
  
  × Getting requirements to build wheel did not run successfully.
  │ exit code: 1
  ╰─> [20 lines of output]
      Traceback (most recent call last):
        File ".../OneTrainer/venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 389, in <module>
          main()
        File ".../OneTrainer/venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 373, in main
          json_out["return_val"] = hook(**hook_input["kwargs"])
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File ".../OneTrainer/venv/lib/python3.12/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 143, in get_requires_for_build_wheel
          return hook(config_settings)
                 ^^^^^^^^^^^^^^^^^^^^^
        File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 331, in get_requires_for_build_wheel
          return self._get_build_requires(config_settings, requirements=[])
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 301, in _get_build_requires
          self.run_setup()
        File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 512, in run_setup
          super().run_setup(setup_script=setup_script)
        File "/tmp/pip-build-env-tir6y4yc/overlay/lib/python3.12/site-packages/setuptools/build_meta.py", line 317, in run_setup
          exec(code, locals())
        File "<string>", line 22, in <module>
      ModuleNotFoundError: No module named 'torch'
      [end of output]
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed to build 'flash_attn' when getting requirements to build wheel

even though the OneTrainer venv is activated.

Installing as recommended on their page (python setup.py install) runs out of memory, even with 64 GB of RAM and 32 GB of swap.

dxqb · Nov 02 '25 21:11

Try with build isolation disabled, and make sure to restrict MAX_JOBS used by ninja (install the ninja package beforehand for a faster build), or you will OOM your system: MAX_JOBS=4 pip install . --no-build-isolation
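pip's default build isolation installs build dependencies into a fresh temporary environment, which is why flash_attn's setup.py could not import torch even inside the activated venv. A sketch of a hypothetical helper that assembles the suggested command (no isolation, capped MAX_JOBS) and flags missing prerequisites, without running anything:

```python
import os
import shutil
import sys
from importlib.util import find_spec


def flash_attn_build_command(max_jobs: int = 4):
    """Hypothetical helper: return (command, environment, warnings) for a source build."""
    build_warnings = []
    if find_spec("torch") is None:
        build_warnings.append("torch is not installed in this environment")
    if shutil.which("ninja") is None:
        build_warnings.append("ninja not found on PATH; the build will be much slower")
    # Fewer parallel compile jobs means lower peak RAM during the build.
    env = dict(os.environ, MAX_JOBS=str(max_jobs))
    cmd = [sys.executable, "-m", "pip", "install", ".", "--no-build-isolation"]
    return cmd, env, build_warnings


cmd, env, build_warnings = flash_attn_build_command(max_jobs=4)
# subprocess.run(cmd, env=env, cwd="<flash-attention checkout>", check=True)
```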

zzlol63 · Nov 02 '25 21:11

Now it complains that my nvcc is too old. I could probably solve this somehow (and thank you), but to be usable in OneTrainer it needs to be simple to install.

Is there a way to install it using pip with no compilation or special options?

dxqb · Nov 02 '25 21:11

@dxqb There are a few repos which offer precompiled wheels, but only certain PyTorch+CUDA+Python configurations are supported.

Below is an example for Linux: https://github.com/mjun0812/flash-attention-prebuild-wheels

I think we should focus on getting FlexAttention into OneTrainer instead (at least it is built into PyTorch itself); FlashAttention might be a rabbit hole and is not a good option to start with.
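The prebuilt wheels mentioned above only cover certain PyTorch+CUDA+Python combinations, so picking one starts with knowing the local combination. A sketch of a hypothetical helper that collects those facts; it does not assume any particular wheel naming scheme.

```python
import platform
import sys
from importlib.util import find_spec


def wheel_environment():
    """Hypothetical helper: facts a prebuilt flash-attn wheel has to match."""
    info = {
        "python_tag": f"cp{sys.version_info.major}{sys.version_info.minor}",
        "os": platform.system(),
        "machine": platform.machine(),
        "torch": None,
        "cuda": None,
    }
    if find_spec("torch") is not None:
        import torch  # imported lazily so the helper also works without torch

        info["torch"] = torch.__version__
        info["cuda"] = torch.version.cuda  # None on CPU-only builds
    return info


info = wheel_environment()
```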

zzlol63 · Nov 02 '25 22:11

Also, last night I fully trained a Chroma LoRA using masked training + prior preservation with FlexAttention (as per https://github.com/Nerogar/OneTrainer/issues/1090#issuecomment-3477943415), and the results turned out fine: the model converged as expected and worked fine in ComfyUI.

zzlol63 · Nov 02 '25 23:11

Benchmarks show that torch SDP is already as fast as flash attention:

flash: 100%|███████████████████████████████████████████████████████████████████████████| 5000/5000 [00:25<00:00, 196.18it/s]
torch SDP: 100%|███████████████████████████████████████████████████████████████████████| 5000/5000 [00:26<00:00, 190.36it/s]
torch SDP no flash: 100%|██████████████████████████████████████████████████████████████| 5000/5000 [00:40<00:00, 122.42it/s]

It is only slower if you specifically instruct torch not to use the flash algorithm (last benchmark line): torch.backends.cuda.enable_flash_sdp(False)
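The benchmark script behind these numbers is not shown in the thread (a `run_benchmark` helper also appears in a later snippet). A minimal stand-in, assuming a simple warm-up plus timed loop; for GPU work, pass a synchronize function so queued kernels are included in the timing.

```python
import time


def run_benchmark(fn, name, iterations=5000, warmup=10, sync=None):
    """Hypothetical benchmark helper: call fn repeatedly and report it/s.

    For CUDA, pass sync=torch.cuda.synchronize; kernels launch asynchronously.
    """
    for _ in range(warmup):  # keep compilation/autotuning out of the timed loop
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iterations):
        fn()
    if sync is not None:
        sync()
    elapsed = max(time.perf_counter() - start, 1e-9)
    rate = iterations / elapsed
    print(f"{name}: {iterations} iterations in {elapsed:.2f}s ({rate:.2f} it/s)")
    return rate


# "no flash" variant: call torch.backends.cuda.enable_flash_sdp(False) before benchmarking.
```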

> I have observed that for training a Chroma LoRA with batch size 3 on an RTX 5090, the seconds per iteration reduced from 9.5s/it to 6.4s/it on my machine.

Please reproduce this step by step. If the above is correct, this cannot have been due to flash attention alone.

dxqb · Nov 03 '25 17:11

As discussed on Discord, the issue appears to be specific to Windows, where the Torch SDP backend cannot use the native FlashAttention-2-based kernel because PyTorch for Windows is not compiled with FlashAttention support in the first place.

Below are benchmark results using different attention implementations that have been torch.compile'd on a dummy input:

[Image: benchmark results for the torch.compile'd attention implementations]

Attempting to force the flash attention backend in PyTorch leads to an error on Windows:

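import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
# q, k, v and run_benchmark come from the surrounding benchmark script (not shown)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    fn = lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None)
    run_benchmark(fn, "torch SDP flash")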
C:\OneTrainer\venv\Lib\site-packages\torch\_dynamo\utils.py:3546: UserWarning: Flash attention kernel not used because: (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:917.)
  return node.target(*args, **kwargs)  # type: ignore[operator]
C:\OneTrainer\venv\Lib\site-packages\torch\_dynamo\utils.py:3546: UserWarning: Torch was not compiled with flash attention. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\aten\src\ATen\native\transformers\cuda\sdp_utils.cpp:726.)

There is an open issue at the pytorch repo regarding support: https://github.com/pytorch/pytorch/issues/108175

This likely explains why switching the attention backend to flex gives a performance boost on my Windows machine.

zzlol63 · Nov 04 '25 09:11

I did some further testing, this time with a fresh dataset (with optional masks) on a natively booted Linux distro, using identical settings in OneTrainer, and came back with the following results. These are specifically for the Chroma model.

[Image: Chroma benchmark results, Linux vs. Windows]

At least with the default attention backend (Torch SDPA), performance is identical across matching PyTorch versions for both Linux and Windows for this specific workload.

The biggest gains on both Windows and Linux are seen in PyTorch 2.9 with FlexAttention (which replicates my earlier findings).

Note that Chroma uses attention-masks during training, so it cannot use the FlashAttention kernel for Torch SDP.

I will run some more tests on a different attention-heavy model that doesn't use attention masks, such as Flux.

zzlol63 · Nov 05 '25 13:11

I was now able to reproduce this:

RTX 5090, 1024 px, bs 3 (s/it)
PyTorch   native   flex
2.9.0     7.0      5.0

7.0 s/it native vs. 5.0 s/it with flex attention.

It does not depend on 2.9.0 specifically; it also works with 2.8.0 (but not with 2.7.1):

RTX 5090, 1024 px, bs 3 (s/it)
PyTorch   native   flex
2.8.0     7.1      5.2

2.8.0 is generally a bit slower, but close enough. All my further tests will use 2.8.0 because I ran into some bugs with 2.9.0 and diffusers.

The large speed-up of about 25% seems to depend on quite special conditions, though:

  1. On my 4070 the speed-up is only about 13% instead of 25%:

     RTX 4070, 1024 px (s/it)
     PyTorch   native   flex
     2.8.0     16.1     14.0

  2. It depends on a large sequence length. With 512 px bs 3, there is maybe a 10% speed-up on a 5090 (difficult to say), and virtually none on my 4070:

     512 px, bs 3 (s/it)
     GPU        PyTorch   native   flex
     RTX 5090   2.8.0     1.2      1.1
     RTX 4070   2.8.0     3.1      3.0
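To put "large sequence length" in numbers: assuming Chroma shares FLUX.1's latent layout (8x VAE downsampling followed by 2x2 patches, i.e. one image token per 16x16 pixels), the image-token count and the per-head attention score matrix grow as follows. Text tokens come on top and are left out of this sketch.

```python
def flux_image_tokens(height_px: int, width_px: int, vae_factor: int = 8, patch: int = 2) -> int:
    """Image tokens for a Flux-style transformer (assumed 8x VAE + 2x2 patchify)."""
    step = vae_factor * patch
    return (height_px // step) * (width_px // step)


for side in (512, 1024):
    n = flux_image_tokens(side, side)
    print(side, n, n * n)  # scores per head grow with n**2
```

Doubling the resolution quadruples the token count and multiplies the attention score matrix by 16, so attention takes a much larger share of each step at 1024 px.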

This combination is why I wasn't able to reproduce it:

  • no speed-up on 2.7.1
  • no speed-up at smaller sequence lengths on my 4070

I would therefore conclude that flex attention can be more efficient, probably because it is compiled and therefore specialized at runtime for whatever sequence length is used, while torch SDP cannot do that. Overall, though, the algorithm is likely the same.

Still not bad under the conditions where it helps, and we could use it once it is stable.

However, somewhat surprisingly, it currently cannot be combined with using torch.compile on the whole model: this PR https://github.com/Nerogar/OneTrainer/pull/1034 fails with flex attention. I'll open a torch issue about this.

Therefore, for now we can only choose one or the other, but using torch.compile for everything without flex is much faster than using it only for flex:

RTX 4070 (s/it)
Resolution     PyTorch   native   flex   torch.compile
1024 px        2.8.0     16.1     14.0   12.5
512 px, bs 3   2.8.0     3.1      3.0    2.2

The torch.compile column uses only torch.compile, not the other features of the PR.

dxqb · Nov 05 '25 19:11

https://github.com/pytorch/pytorch/issues/167116

dxqb · Nov 05 '25 20:11

I ran a set of tests using FLUX.1-dev on the same dataset. I posted some numbers previously but realised I had made a big mistake: the FlexAttention backend wasn't actually active. Below are the results:

[Image: FLUX.1-dev benchmark results, Linux vs. Windows]

As far as standard Torch SDP goes, Linux outperforms Windows in every scenario, likely because the Linux binaries have the FlashAttention kernel compiled in, whereas even the latest PyTorch builds for Windows do not.

However, with PyTorch 2.9, using FlexAttention on Windows pretty much closes that performance gap.

zzlol63 · Nov 06 '25 06:11

> However, with PyTorch 2.9, using FlexAttention on Windows pretty much closes that performance gap.

Using FlexAttention is currently not an option, and probably not in the near future. Torch calls it a "prototype feature" and there are probably more open issues with it than the one we've found above.

So if the torch SDP algorithm is much worse on Windows, the only alternative would be another external flash attention implementation, for example flash_attn (with an easy install method), which is where we started above.

dxqb · Nov 06 '25 09:11

> So if on windows, the torch SDP algorithm is much worse, the only alternative would be to use another external flash attention algorithm. For example by using flash_attn (with an easy install method), which is how we started above.

Yes, with caveats obviously: it is only suitable for certain models.
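One way to encode that caveat, summarizing the findings in this thread: standard flash_attn does not take an attention mask, so models that train with masks (Chroma, per the results above) are limited to mask-capable backends, while mask-free models such as Flux could also use `flash`. The helper below is a hypothetical selection rule, not an existing OneTrainer or Diffusers function.

```python
def candidate_backends(model_uses_attention_mask: bool) -> list:
    """Hypothetical rule: which Diffusers backend names are safe for a model."""
    backends = ["native", "flex"]  # both accept attention masks
    if not model_uses_attention_mask:
        backends.append("flash")  # standard flash_attn has no mask support
    return backends


chroma_options = candidate_backends(model_uses_attention_mask=True)
flux_options = candidate_backends(model_uses_attention_mask=False)
```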

zzlol63 · Nov 06 '25 09:11

From Discord:

> I actually get significant speedup from [flash varlen attention], but I had to manipulate the tensors a bit to get it to pick the optimized kernel. I also made sure I was precomputing the masks per batch rather than naively doing it per layer. This is a bit messy, but should give you the general concept

https://gist.github.com/deepdelirious/ed1dcbc131ce6cad7a5a196c6c09aa24
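Varlen kernels replace a dense mask with packed sequences plus cumulative offsets: flash_attn's varlen functions take `cu_seqlens_q`/`cu_seqlens_k` (int32, length batch+1, starting at 0) and `max_seqlen_q`/`max_seqlen_k`. A minimal sketch of computing that metadata once per batch so every layer can reuse it; this illustrates the idea, not the gist's code.

```python
from itertools import accumulate


def varlen_metadata(seq_lens):
    """Cumulative offsets and max length for packed variable-length attention."""
    cu_seqlens = [0, *accumulate(seq_lens)]
    return cu_seqlens, max(seq_lens)


cu, max_len = varlen_metadata([3, 5, 2])
# cu == [0, 3, 8, 10], max_len == 5
# On GPU: torch.tensor(cu, dtype=torch.int32, device="cuda"), built once per batch.
```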

dxqb · Dec 04 '25 07:12