SageAttention
Draft PR: SageAttention
This PR applies SageAttention to the attention processors in the model's UNet/Transformer. By introducing SageAttention 2.1.1, we expect a training speed improvement of approximately 50–75%.
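For context, a minimal sketch of what swapping `scaled_dot_product_attention` for `sageattn` inside a diffusers-style attention processor could look like. The class name `SageAttnProcessor` is illustrative, not the PR's actual code; attention masks, norms, and dropout handling are omitted, and the `tensor_layout`/`is_causal` parameter names follow my reading of the SageAttention 2.x API.

```python
import torch
from sageattention import sageattn  # requires SageAttention 2.1.1 built from source


class SageAttnProcessor:
    """Simplified sketch of a diffusers-style processor using sageattn for the forward pass."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, **kwargs):
        # Fall back to self-attention when no cross-attention context is given.
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(context)
        value = attn.to_v(context)

        batch = query.shape[0]
        head_dim = query.shape[-1] // attn.heads

        # Reshape to (batch, heads, seq, head_dim) to match tensor_layout="HND".
        query = query.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch, -1, attn.heads, head_dim).transpose(1, 2)

        # sageattn replaces SDPA for the forward pass only; it currently has
        # no backward pass (see the autograd discussion below).
        hidden_states = sageattn(query, key, value, tensor_layout="HND", is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch, -1, attn.heads * head_dim)
        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout
        return hidden_states
```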
Requirements
- Triton: Must be installed beforehand.
- Building SageAttention:
Build version 2.1.1 from source using the [SageAttention GitHub Repository](https://github.com/thu-ml/SageAttention). Note: the pip installation defaults to version 1.0.6, which provides only minimal speed improvements.
- Alternatively, you may choose to install using a volunteer-built wheel (not provided here).
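A minimal pre-flight check along these lines could gate the new processor behind the required versions. This is a sketch; the distribution names `triton` and `sageattention` are assumptions about how the packages register themselves.

```python
from importlib.metadata import PackageNotFoundError, version


def sage_attention_available(min_version=(2, 1, 1)) -> bool:
    """Return True only if Triton and a recent-enough SageAttention build are installed."""
    try:
        version("triton")                     # Triton must be installed beforehand
        installed = version("sageattention")  # assumed distribution name
    except PackageNotFoundError:
        return False
    # Keep only the leading numeric components so local tags like "+cu121" don't break parsing.
    parts = []
    for piece in installed.split(".")[:3]:
        digits = "".join(ch for ch in piece if ch.isdigit())
        parts.append(int(digits) if digits else 0)
    # The pip release (1.0.6) offers only minimal speedups, so require the source-built 2.1.1+.
    return tuple(parts) >= min_version
```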
Implementation & Verified Operation
| Model | Implemented | Tested |
|---|---|---|
| SD | ✅ | ❌ |
| SDXL | ✅ | ✅ |
| Pixart Alpha | ✅ | ❌ |
| Flux.1 | ✅ | ❌ |
| Others | ❌ | ❌ |
Todo
- Verify Training Consistency and Tolerance: minor discrepancies are observed during initial inference, but they are aesthetically acceptable (see the comparison sketch after this list).
- Implement Support for Additional Models.
- Update Documentation.
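A rough sketch of how the consistency check could be run, comparing `sageattn` against PyTorch SDPA on random inputs. The shapes and tolerance values below are placeholders, not validated thresholds.

```python
import torch
import torch.nn.functional as F
from sageattention import sageattn

# Arbitrary SDXL-like shapes: (batch, heads, seq, head_dim), fp16 on CUDA.
q, k, v = (torch.randn(1, 24, 4096, 64, dtype=torch.float16, device="cuda") for _ in range(3))

reference = F.scaled_dot_product_attention(q, k, v)
candidate = sageattn(q, k, v, tensor_layout="HND", is_causal=False)

max_abs_diff = (reference - candidate).abs().max().item()
print(f"max abs diff: {max_abs_diff:.5f}")
print("within tolerance:", torch.allclose(reference, candidate, atol=5e-2, rtol=0))
```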
Were you able to get it to train? From what I was hearing on their repo, it doesn't support autograd yet and thus can't be used for training. https://github.com/thu-ml/SageAttention/issues/60
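One quick way to check the autograd situation locally is to attempt a backward pass through `sageattn`. A minimal sketch, with arbitrary shapes and the API names per my reading of SageAttention 2.x:

```python
import torch
from sageattention import sageattn

q, k, v = (torch.randn(1, 8, 256, 64, dtype=torch.float16, device="cuda", requires_grad=True)
           for _ in range(3))
try:
    out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
    out.sum().backward()
    print("backward ran; q.grad is", "populated" if q.grad is not None else "missing")
except (RuntimeError, NotImplementedError) as err:
    print("no autograd support:", err)
```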
I slapped your code on top of the current OneTrainer and trained a style LoRA to see if this works.
Training time is 20 min without and 16 min with SA, a 25% speedup. But the training results have obviously suffered.
I thought it could be because of masked training, so I tried with and without masked training too.
It may be that it needs tuning of other settings to achieve the same quality, or it just doesn't work for now.
Trained on noobai v-pred v1. um = unmasked, m = masked.
Since the authors of SageAttention themselves say that training is not yet possible with their implementation, I'm marking this as a draft. It doesn't make sense to merge in its current state.
Even if the SageAttention authors implement backward-pass support in the future, they appear to plan it only for Hopper-architecture GPUs. I have implemented the same approach using FlashAttention 2 and verified that it works, so please consider that option. It supports backpropagation, though to preserve full precision you might need to enable deterministic mode (sacrificing VRAM); this implementation leaves deterministic=False.
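For reference, a minimal sketch of what the FlashAttention 2 path could look like, not the actual implementation. Note that `flash_attn_func` expects `(batch, seq, heads, head_dim)` fp16/bf16 tensors, and that `deterministic=True` makes the backward pass reproducible at the cost of extra VRAM; the sketch leaves it at False, as in the comment above.

```python
import torch
from flash_attn import flash_attn_func

# Arbitrary shapes: (batch, seq, heads, head_dim), fp16 on CUDA.
q, k, v = (torch.randn(1, 4096, 24, 64, dtype=torch.float16, device="cuda", requires_grad=True)
           for _ in range(3))

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=False, deterministic=False)
out.sum().backward()  # backward is supported, unlike SageAttention
print("q.grad shape:", q.grad.shape)
```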
Can you expand a little on this? Does it mean you can use int4 attention over FA2 during the entire process?
I'd suggest closing this PR. It is a feature request for SageAttention, not for OneTrainer. Maybe SageAttention will implement it, and then we can use it. Or other projects will be faster and SageAttention will be outdated by then; for example, nunchaku could implement backward propagation.
I agree. Closing.