Redo custom attention processor to support other attention types
Summary
The current attention processor implements only the torch-sdp attention type, so when any IP-Adapter or regional prompt is used, we override the model to run torch-sdp attention.
The new attention processor combines the four attention processors (normal, sliced, xformers, torch-sdp) by moving the parts of attention that differ (mask preparation and the attention computation itself) into a separate function call, where the required implementation is executed.
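Roughly, the dispatch looks like the sketch below. This is a minimal illustration, not the actual InvokeAI code: the class and method names (`CustomAttnProcessor`, `_run_attention`), the tensor layout, and the string keys are assumptions, and the sliced/xformers branches are omitted for brevity.

```python
from typing import Optional

import torch
import torch.nn.functional as F


class CustomAttnProcessor:
    """Illustrative processor that dispatches on the configured attention type."""

    def __init__(self, attention_type: str = "torch-sdp"):
        self.attention_type = attention_type

    def _run_attention(
        self,
        query: torch.Tensor,  # (batch, heads, seq_len, head_dim)
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Only the part that differs between implementations lives here; the
        # shared pre/post-processing stays in the common processor code.
        if self.attention_type == "torch-sdp":
            return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
        if self.attention_type == "normal":
            scale = query.shape[-1] ** -0.5
            scores = query @ key.transpose(-2, -1) * scale
            if attention_mask is not None:
                scores = scores + attention_mask
            return scores.softmax(dim=-1) @ value
        raise ValueError(f"Unknown attention_type: {self.attention_type}")
```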
Related Issues / Discussions
None
QA Instructions
Change `attention_type` in `invokeai.yaml` and then run a generation with an IP-Adapter or a regional prompt.
Merge Plan
None?
Checklist
- [x] The PR has a short but descriptive title, suitable for a changelog
- [ ] Tests added / updated (if applicable)
- [ ] Documentation added / updated (if applicable)
@dunkeroni @RyanJDick
I haven't looked at the code yet, but do you know if there are still use cases for attention processors other than Torch 2.0 SDP? Based on the benchmarking that diffusers has done, it seems like the all-around best choice. But maybe there are still reasons to use the other implementations, e.g. on a very-low-VRAM system?
I thought roughly the same:
- normal - generally no need for it
- xformers - if torch-sdp is on par or even faster, then it too can be removed
- sliced - yes, it's suitable for low-memory situations, and I think it's the main attention implementation for MPS
On CUDA, torch's SDP was faster than xformers for me when I last checked a month or so back. IIRC it was just a couple % faster.
I thought about this some more, and I'm hesitant to proceed with trying to merge this until we have more clarity around which attention implementations we actually want to support.
Right now, we have `_adjust_memory_efficient_attention`, which tries to configure attention based on the config and the system properties. The logic in this function is outdated, and I think there has been hesitation to change it out of fear of causing a regression on some systems. Let's get to the bottom of this, before deciding how to proceed with this PR.
My current guess is that just supporting torch SDP and sliced attention would cover all use cases. But, we need to do some testing to determine if this is accurate.
A few data points to consider:
- https://pytorch.org/blog/accelerated-diffusers-pt-20/
- https://huggingface.co/docs/diffusers/v0.13.0/en/optimization/torch2.0
- I did a quick experiment on an RTX4090 and saw a speedup from choosing SDP over xformers (not currently the default behaviour).
@StAlKeR7779 do you want to look into this?
@RyanJDick OK, I removed the normal and xformers attentions.
But some parts relate to the frontend, so I hope @psychedelicious will look at those.
There's also the question of what to do with the config. I think old configs can be migrated to convert the normal/xformers values, but I'm not familiar with this part and will look at it more closely later. Maybe @lstein can suggest how to do it.
Update: I've already added a config migration, but I'd welcome feedback if I've done something wrong in it.
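Roughly what the migration does, as a simplified sketch (not the actual migration code; the function and field names are illustrative, and since at this point normal and xformers were removed, they map onto torch-sdp):

```python
def migrate_attention_config(config: dict) -> dict:
    """Map removed attention_type values onto the remaining implementation."""
    if config.get("attention_type") in ("normal", "xformers"):
        config["attention_type"] = "torch-sdp"
    return config
```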
Also, I can confirm that both in this PR and in main, when rendering a 1024x1024 image with an SDXL model, slice sizes of 3, 6, 7 and 8 produce mangled images. I guess this is the floor-division error in diffusers that psychedelicious flagged.
I patched this in our attention processor, so if you add any IP-Adapter or regional prompt you should get normal output.
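For context, here is a hypothetical illustration of why non-divisible slice sizes mangle the output (the numbers are made up, not SDXL's actual head counts):

```python
heads, slice_size = 10, 3           # illustrative values only
num_slices = heads // slice_size    # floor division -> 3 slices
covered = num_slices * slice_size   # only 9 of 10 heads are processed; the
                                    # remaining head is silently skipped, which
                                    # is what the patch has to account for
```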
The branch with our attention processor currently executes only if we have an IP-Adapter or a regional prompt. We could change that by removing `_adjust_memory_efficient_attention`, but I'm unsure whether it should be done in this PR, or at all, since this code will be removed anyway once the modular backend is completed.
Brought back normal attention, as tests show that it runs faster than torch-sdp on some MPS systems.
I also 'merged' sliced attention into these attention types, so we select `attention_type=normal`/`torch-sdp` and can then add `attention_slice_size` if we want sliced attention. Sliced attention now runs the algorithm from `attention_type` on each slice; previously it always used the algorithm from normal attention.
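Roughly how the merged behaviour works, as a minimal sketch (the function name, tensor layout, and the idea of passing the configured attention implementation as a callable are assumptions for illustration, not the actual code):

```python
from typing import Callable, Optional

import torch


def run_sliced_attention(
    run_attention: Callable[..., torch.Tensor],  # the implementation selected by attention_type
    query: torch.Tensor,  # (batch, heads, seq_len, head_dim)
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    slice_size: int,
) -> torch.Tensor:
    out = torch.zeros_like(query)
    # Round up so a head count that isn't divisible by slice_size is still
    # fully covered; the last slice is simply shorter.
    num_slices = -(-query.shape[1] // slice_size)
    for i in range(num_slices):
        s = slice(i * slice_size, (i + 1) * slice_size)
        mask_slice = attention_mask[:, s] if attention_mask is not None else None
        out[:, s] = run_attention(query[:, s], key[:, s], value[:, s], mask_slice)
    return out
```

With this, a config such as `attention_type: torch-sdp` plus `attention_slice_size: 4` (hypothetical values) would still run the torch-sdp math on every slice.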
Please run `scripts/update_config_docstring.py` to update the config docstrings.
How critical is it to remove xformers completely vs leaving it as an option?
torch-sdp is 4.8x slower on older-generation Pascal GPUs, like the P40, P100, or the Nvidia 10xx cards. For SDXL, this means 8.8 vs 1.8 seconds/iter.
Other than that, torch-sdp is slightly faster on Ampere and Turing. Likely also on Ada, but I haven't tested that one.
As I said, it costs nothing to keep supporting it in the code. But it looked like most people thought it should be removed, so I removed it.
I can easily bring it back.
Should we select the attention type based on CUDA compute capability, so that the default is torch-sdp from 7.5 onwards and xformers (if available) for older GPUs?
That seems reasonable, with a configurable override if the user wants to force one.
Does look like we should add it back.
The PyTorch blog says Flash Attention is supported from sm80 compute capability onwards: https://pytorch.org/blog/accelerated-pytorch-2/, so perhaps we should default to xformers for anything lower than that. (As @hipsterusername said, with a configurable override).
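Something along these lines, as a hedged sketch (the helper name is hypothetical, the 8.0 threshold follows the sm80 figure from the PyTorch blog post, and the config value would still override whatever this returns):

```python
import torch


def default_attention_type() -> str:
    """Pick a default attention implementation from the CUDA compute capability."""
    if not torch.cuda.is_available():
        return "torch-sdp"
    major, _minor = torch.cuda.get_device_capability()
    if major >= 8:  # sm80+ (Ampere and newer): Flash Attention via torch SDP
        return "torch-sdp"
    try:
        import xformers  # noqa: F401
    except ImportError:
        return "torch-sdp"
    return "xformers"
```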
Not for this PR, but I did some performance testing and we'll probably want to address this at some point:
SDXL:
```
>>> Time taken to prepare attention processors: 0.10069823265075684s
>>> Time taken to prepare attention processors: 0.07877492904663086s
>>> Time taken to set attention processors: 0.1278061866760254s
>>> Time taken to reset attention processors: 0.13225793838500977s
```
Code used to measure:
```python
import time
from contextlib import contextmanager

from diffusers import UNet2DConditionModel


@contextmanager
def apply_custom_attention(self, unet: UNet2DConditionModel):
    """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
    start = time.time()
    attn_procs = self._prepare_attention_processors(unet)
    time_1 = time.time()
    print(f">>> Time taken to prepare attention processors: {time_1 - start}s")

    # Grab the original processors so they can be restored when the context exits.
    orig_attn_processors = unet.attn_processors
    time_2 = time.time()
    print(f">>> Time taken to prepare attention processors: {time_2 - time_1}s")

    try:
        # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
        # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
        # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
        unet.set_attn_processor(attn_procs)
        time_3 = time.time()
        print(f">>> Time taken to set attention processors: {time_3 - time_2}s")
        yield None
    finally:
        time_4 = time.time()
        unet.set_attn_processor(orig_attn_processors)
        time_5 = time.time()
        print(f">>> Time taken to reset attention processors: {time_5 - time_4}s")
```
It looks like there was a significant re-write of the attention logic after the latest round of review and testing on this PR. @StAlKeR7779 can you shed some light on the benefits / motivation for that latest re-write?
Given the amount of testing and discussion that went into this branch before the re-write, I'm wondering if we should revert those latest changes, merge this PR, and then open a new PR with the changes. Let me know what you think.