
SageAttention


Draft PR: SageAttention

This PR applies SageAttention to the Attention Processor in the model’s UNet/Transformer. By introducing SageAttention 2.1.1, we expect to achieve a training speed improvement of approximately 50–75%.
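For context, here is a minimal sketch (not the PR's actual code) of one way a diffusers-style attention processor can be routed through SageAttention. It assumes the `sageattention` package exposing `sageattn`, and it omits attention masks, group/spatial norms, and other details for brevity:

```python
# Minimal sketch only -- not the PR's implementation. Assumes the
# `sageattention` package (2.x) and a diffusers-style Attention module.
from sageattention import sageattn


class SageAttnProcessor:
    """Simplified drop-in replacement for diffusers' AttnProcessor2_0."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 attention_mask=None, **kwargs):
        # Cross-attention takes keys/values from the encoder states.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        batch = query.shape[0]
        head_dim = query.shape[-1] // attn.heads

        # Reshape to (batch, heads, seq, head_dim), the "HND" layout.
        query = query.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch, -1, attn.heads, head_dim).transpose(1, 2)

        # SageAttention kernel in place of F.scaled_dot_product_attention.
        out = sageattn(query, key, value, tensor_layout="HND", is_causal=False)

        out = out.transpose(1, 2).reshape(batch, -1, attn.heads * head_dim)
        out = attn.to_out[0](out)  # output projection
        out = attn.to_out[1](out)  # dropout
        return out


# Hypothetical usage: unet.set_attn_processor(SageAttnProcessor())
```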

Requirements

  • Triton: Must be installed beforehand.
  • Building SageAttention:
    Build version 2.1.1 from source using the [SageAttention GitHub repository](https://github.com/thu-ml/SageAttention).

    Note: The pip installation defaults to version 1.0.6, which provides only minimal speed improvements (see the version-check sketch after this list).

  • Alternatively, you may choose to install using a volunteer-built wheel (not provided here).
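As referenced in the note above, a small runtime guard (hypothetical, not part of the PR) can warn when the slow 1.x pip wheel is installed instead of a 2.x source build:

```python
# Hypothetical guard, not part of the PR: warn if the 1.x pip wheel of
# sageattention is installed instead of a 2.x build from source.
from importlib.metadata import PackageNotFoundError, version

try:
    sage_version = version("sageattention")
    if sage_version.startswith("1."):
        print(f"sageattention {sage_version} found; build >= 2.1.1 from source "
              "to get the full speedup.")
except PackageNotFoundError:
    print("sageattention is not installed.")
```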

Implementation & Verified Operation

| Model | Implemented | Tested |
| --- | --- | --- |
| SD | | |
| SDXL | | |
| Pixart Alpha | | |
| Flux.1 | | |
| Others | | |

Todo

  • Verify Training Consistency and Tolerance:
    (Minor discrepancies are observed during initial inference, but they are aesthetically acceptable.)
  • Implement Support for Additional Models.
  • Update Documentation.

celll1 · Apr 15 '25

Were you able to get it to train? From what I was hearing on their repo, it doesn't support autograd yet, so it can't be used for training. https://github.com/thu-ml/SageAttention/issues/60
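One quick way to check this locally (an illustrative probe, not something from the linked issue) is to run a tiny forward/backward through `sageattn` and see whether autograd produces gradients:

```python
# Illustrative probe: does the installed SageAttention kernel register a
# backward pass? If not, backward() raises because the output has no grad_fn.
import torch
from sageattention import sageattn

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
try:
    out.sum().backward()
    print("backward ran, grad norm:", q.grad.norm().item())
except RuntimeError as err:
    print("no usable backward:", err)
```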

rockerBOO · Apr 18 '25

I slapped your code on top of the current OneTrainer and trained a style LoRA to see if this works. Training time is 20 min without SageAttention and 16 min with it, a 25% speedup, but the training results have clearly suffered. I thought it might be caused by masked training, so I tried with and without masked training as well. It may need tuning of other settings to reach the same quality, or it simply doesn't work for now. Trained on NoobAI v-pred v1; um = unmasked, m = masked. (attached comparison grid: xyz_grid-0005-622598234)

Gebsfrom404 · Apr 20 '25

Since the authors of SageAttention themselves say that training is not yet possible with their implementation, I'm marking this as a draft. It doesn't make sense to merge it in its current state.

Nerogar · Apr 21 '25

The SageAttention author appears to plan backward-pass support only for Hopper-architecture GPUs, even if it is implemented in the future. I have implemented the same approach using FlashAttention 2 and verified that it works, so please consider that option. It supports backpropagation; to preserve full precision you may need to enable deterministic mode (at the cost of extra VRAM), but this implementation leaves deterministic=False.
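For reference, a sketch of the FlashAttention 2 call being described, under the assumption of flash-attn 2.4+ with fp16/bf16 tensors on CUDA; the deterministic flag is the switch mentioned above:

```python
# Sketch under assumptions: flash-attn 2.4+, fp16/bf16 CUDA tensors in
# (batch, seqlen, heads, head_dim) layout. Backward is supported.
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# deterministic=True trades extra VRAM for a reproducible backward pass;
# as noted above, this implementation keeps the default deterministic=False.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False, deterministic=False)
out.mean().backward()
```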

celll1 · Apr 28 '25

> The SageAttention author appears to plan backward-pass support only for Hopper-architecture GPUs, even if it is implemented in the future. I have implemented the same approach using FlashAttention 2 and verified that it works, so please consider that option. It supports backpropagation; to preserve full precision you may need to enable deterministic mode (at the cost of extra VRAM), but this implementation leaves deterministic=False.

Can you expand a little on this? Does it mean you can use int4 attention over FA2 for the entire process?

sorasoras · May 08 '25

I'd suggest closing this PR. It is a feature request for SageAttention, not for OneTrainer. Maybe SageAttention will implement it, and then we can use it. Or other projects will be faster and SageAttention will be outdated by then; for example, nunchaku could implement backward propagation.

dxqb · Aug 20 '25

I agree. Closing.

O-J1 · Aug 20 '25