
Redo custom attention processor to support other attention types

Open StAlKeR7779 opened this issue 1 year ago • 15 comments

Summary

The current attention processor implements only the torch-sdp attention type, so whenever an IP-Adapter or regional prompt is used, we override the model to run torch-sdp attention. The new attention processor combines four attention processors (normal, sliced, xformers, torch-sdp) by moving the parts that differ between them (mask preparation and the attention computation itself) into a separate function call, which executes the required implementation.
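To make the "shared processor, pluggable backend" idea concrete, here is a minimal pure-Python sketch. It is not the PR's code: the names CombinedAttnProcessor, BACKENDS, to_additive_bias and streaming_attention are hypothetical, and the two backends are small stand-ins (one materialises the full score matrix, one uses an online softmax) rather than real torch-sdp or xformers kernels. What it shows is that each backend pairs its own mask preparation with its own attention function, while the surrounding processor logic stays identical.

```python
import math
from typing import Callable, Optional

Matrix = list[list[float]]
KeepMask = list[list[bool]]  # keep[i][j] is True if query i may attend to key j


def _matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def _softmax(row: list[float]) -> list[float]:
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


# Backend-specific mask preparation (hypothetical formats).
def to_additive_bias(keep: Optional[KeepMask]) -> Optional[Matrix]:
    if keep is None:
        return None
    return [[0.0 if k else -math.inf for k in row] for row in keep]


def as_boolean(keep: Optional[KeepMask]) -> Optional[KeepMask]:
    return keep


# Backend-specific attention kernels (pure-Python stand-ins).
def normal_attention(q: Matrix, k: Matrix, v: Matrix, bias: Optional[Matrix]) -> Matrix:
    """Materialises the full score matrix, then applies a row-wise softmax."""
    scale = 1.0 / math.sqrt(len(q[0]))
    k_t = [list(col) for col in zip(*k)]
    scores = [[s * scale for s in row] for row in _matmul(q, k_t)]
    if bias is not None:
        scores = [[s + b for s, b in zip(r, br)] for r, br in zip(scores, bias)]
    return _matmul([_softmax(r) for r in scores], v)


def streaming_attention(q: Matrix, k: Matrix, v: Matrix, keep: Optional[KeepMask]) -> Matrix:
    """Online softmax: visits one key at a time and never stores a full score row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        running_max, denom, acc = -math.inf, 0.0, [0.0] * len(v[0])
        for j, (kj, vj) in enumerate(zip(k, v)):
            if keep is not None and not keep[i][j]:
                continue
            s = scale * sum(a * b for a, b in zip(qi, kj))
            new_max = max(running_max, s)
            correction = math.exp(running_max - new_max)  # rescales earlier partial sums
            w = math.exp(s - new_max)
            denom = denom * correction + w
            acc = [a * correction + w * x for a, x in zip(acc, vj)]
            running_max = new_max
        out.append([a / denom for a in acc])
    return out


# attention_type -> (mask preparation, attention kernel)
BACKENDS: dict[str, tuple[Callable, Callable]] = {
    "normal": (to_additive_bias, normal_attention),
    "streaming": (as_boolean, streaming_attention),
}


class CombinedAttnProcessor:
    def __init__(self, attention_type: str):
        if attention_type not in BACKENDS:
            raise ValueError(f"unknown attention_type: {attention_type!r}")
        self._prepare_mask, self._attend = BACKENDS[attention_type]

    def __call__(self, q: Matrix, k: Matrix, v: Matrix, keep_mask: Optional[KeepMask] = None) -> Matrix:
        # Shared logic (projections, IP-Adapter, regional prompts) would wrap this call;
        # only mask preparation and the kernel differ between backends.
        return self._attend(q, k, v, self._prepare_mask(keep_mask))
```

Both backends return the same values for the same inputs, so switching attention_type changes only how the result is computed.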

Related Issues / Discussions

None

QA Instructions

Change attention_type in invokeai.yaml, then run a generation that uses an IP-Adapter or a regional prompt.

Merge Plan

None?

Checklist

  • [x] The PR has a short but descriptive title, suitable for a changelog
  • [ ] Tests added / updated (if applicable)
  • [ ] Documentation added / updated (if applicable)

@dunkeroni @RyanJDick

StAlKeR7779 avatar Jun 27 '24 13:06 StAlKeR7779

I haven't looked at the code yet, but do you know if there are still use cases for attention processors other than Torch 2.0 SDP? Based on the benchmarking that diffusers has done, it seems like the all-around best choice. But maybe there are still reasons to use another implementation, e.g. on very-low-VRAM systems?

RyanJDick avatar Jun 27 '24 14:06 RyanJDick

I thought roughly the same:
  • normal: generally not needed.
  • xformers: if, as you said, torch-sdp is on par or even faster, it can be removed too.
  • sliced: still useful, since it suits low-memory situations, and I think it's the main attention type for MPS.

StAlKeR7779 avatar Jun 27 '24 14:06 StAlKeR7779

On CUDA, torch's SDP was faster than xformers for me when I last checked, a month or so ago. IIRC it was just a couple of percent faster.

psychedelicious avatar Jun 28 '24 02:06 psychedelicious

I thought about this some more, and I'm hesitant to proceed with trying to merge this until we have more clarity around which attention implementations we actually want to support.

Right now, we have _adjust_memory_efficient_attention, which tries to configure attention based on the config and the system properties. The logic in this function is outdated, and I think there has been hesitation to change it out of fear of causing a regression on some systems. Let's get to the bottom of this before deciding how to proceed with this PR.

My current guess is that just supporting torch SDP and sliced attention would cover all use cases. But, we need to do some testing to determine if this is accurate.

A few data points to consider:

  • https://pytorch.org/blog/accelerated-diffusers-pt-20/
  • https://huggingface.co/docs/diffusers/v0.13.0/en/optimization/torch2.0
  • I did a quick experiment on an RTX4090 and saw a speedup from choosing SDP over xformers (not currently the default behaviour).

@StAlKeR7779 do you want to look into this?

RyanJDick avatar Jul 04 '24 18:07 RyanJDick

@RyanJDick OK, I removed the normal and xformers attention types. Some of the changes touch the frontend, so I hope @psychedelicious will look at those. There is also the question of what to do with the config: I think old configs can be migrated to convert the normal/xformers values, but I'm not familiar with that part and will look at it more closely later. Maybe @lstein can suggest how to do it. Update: I've already added a config migration, but I'd welcome feedback if I did something wrong in it.
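A config migration of the kind described above could be as small as rewriting one key. This is a hypothetical sketch, not the migration that was added in the PR: the function name migrate_attention_type and the choice to map removed values to "auto" are assumptions, and real InvokeAI config migrations may also track a schema version, which is not shown here.

```python
from typing import Any

# Values dropped by this PR at this point in the discussion (per the comment above).
REMOVED_ATTENTION_TYPES = {"normal", "xformers"}
# Assumption: fall back to automatic selection rather than picking a specific backend.
FALLBACK_ATTENTION_TYPE = "auto"


def migrate_attention_type(config: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of a parsed config dict with removed attention_type values replaced."""
    migrated = dict(config)  # shallow copy so the caller's dict is not mutated
    if migrated.get("attention_type") in REMOVED_ATTENTION_TYPES:
        migrated["attention_type"] = FALLBACK_ATTENTION_TYPE
    return migrated
```

Keeping the migration a pure function of the old dict makes it easy to unit-test each legacy value on its own.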

StAlKeR7779 avatar Jul 28 '24 00:07 StAlKeR7779

Also, I can confirm that both in this PR and in main, when rendering a 1024x1024 image with an SDXL model, slice sizes of 3, 6, 7, and 8 produce mangled images. I guess this is the floor-division error in diffusers that psychedelicious flagged.

lstein avatar Aug 03 '24 13:08 lstein


I patched this in our attention processor, so if you add any IP-Adapter or regional prompt you should get normal output. Currently, the branch that uses our attention processor executes only when an IP-Adapter or regional prompt is present. We could change that by removing _adjust_memory_efficient_attention, but I'm unsure whether it should be done in this PR, or at all, since this code will be removed entirely a bit later, once the modular backend is completed.
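The following sketch illustrates how a floor-division slice count can silently skip work. It does not reproduce the diffusers code; the helper names are hypothetical, and the totals 20 and 40 are an assumption standing in for the sliced dimension (e.g. batch size times attention heads) in the SDXL case, which we have not confirmed. With floor division, any slice size that does not divide the total leaves its trailing rows unprocessed; ceiling division processes the final, shorter slice as well.

```python
def floor_slices(total: int, slice_size: int) -> list[tuple[int, int]]:
    # Buggy pattern: total // slice_size drops the final partial slice.
    return [(i * slice_size, (i + 1) * slice_size) for i in range(total // slice_size)]


def ceil_slices(total: int, slice_size: int) -> list[tuple[int, int]]:
    # Ceiling division, with the last slice clamped to the total.
    count = -(-total // slice_size)
    return [(i * slice_size, min((i + 1) * slice_size, total)) for i in range(count)]


def rows_skipped(total: int, slices: list[tuple[int, int]]) -> int:
    return total - sum(end - start for start, end in slices)


def lossy_slice_sizes(totals: tuple[int, ...], max_size: int = 8) -> list[int]:
    """Slice sizes in 1..max_size for which floor slicing skips rows for any total."""
    return sorted(
        {s for s in range(1, max_size + 1) for t in totals if rows_skipped(t, floor_slices(t, s))}
    )


# Hypothetical sliced-dimension sizes; not measured from InvokeAI.
ASSUMED_TOTALS = (20, 40)
print(lossy_slice_sizes(ASSUMED_TOTALS))  # [3, 6, 7, 8]
```

Under that assumption, the sizes that lose rows are exactly 3, 6, 7 and 8, which is consistent with the report above, but this is only suggestive until the real tensor shapes are checked.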

StAlKeR7779 avatar Aug 03 '24 14:08 StAlKeR7779

I brought back normal attention, as tests show it runs faster than torch-sdp on some MPS systems. I also merged sliced attention into these attention types: you select attention_type=normal/torch-sdp and can then set attention_slice_size if you want sliced attention. Sliced attention now uses the algorithm selected by attention_type for each slice; previously it always used the normal attention algorithm.
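The "slicing wraps whichever kernel attention_type selects" design can be sketched as a higher-order function. This is an illustration only: KERNELS, sliced and select_attention are hypothetical names, only a pure-Python "normal" kernel is registered, and a real implementation would register a torch-sdp kernel alongside it. Stepping range() by slice_size also avoids the floor-division pitfall, because the last, shorter slice is visited.

```python
import math
from typing import Callable, Optional

Matrix = list[list[float]]
# A batched kernel takes per-head q/k/v matrices and returns per-head outputs.
BatchedAttn = Callable[[list[Matrix], list[Matrix], list[Matrix]], list[Matrix]]


def _attend_one(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))])
    return out


def normal_batched(qs: list[Matrix], ks: list[Matrix], vs: list[Matrix]) -> list[Matrix]:
    return [_attend_one(q, k, v) for q, k, v in zip(qs, ks, vs)]


# Hypothetical registry; real code would also register e.g. a torch-sdp kernel.
KERNELS: dict[str, BatchedAttn] = {"normal": normal_batched}


def sliced(kernel: BatchedAttn, slice_size: int) -> BatchedAttn:
    """Run `kernel` on at most `slice_size` heads at a time to bound peak memory."""
    def run(qs: list[Matrix], ks: list[Matrix], vs: list[Matrix]) -> list[Matrix]:
        out: list[Matrix] = []
        # Stepping by slice_size also covers the final, shorter slice.
        for start in range(0, len(qs), slice_size):
            end = start + slice_size
            out.extend(kernel(qs[start:end], ks[start:end], vs[start:end]))
        return out
    return run


def select_attention(attention_type: str, attention_slice_size: Optional[int] = None) -> BatchedAttn:
    kernel = KERNELS[attention_type]
    return kernel if attention_slice_size is None else sliced(kernel, attention_slice_size)
```

Because slicing only changes how many heads each call sees, the sliced and unsliced paths should produce identical results for any per-head kernel.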

StAlKeR7779 avatar Aug 03 '24 23:08 StAlKeR7779

Please run scripts/update_config_docstring.py to update the config docstrings.

psychedelicious avatar Aug 04 '24 07:08 psychedelicious

How critical is it to remove xformers completely vs leaving it as an option?

torch-sdp is 4.8x slower on older-generation Pascal GPUs, such as the P40, the P100, or the NVIDIA 10xx series. For SDXL, this means 8.8 vs. 1.8 seconds/iteration.

Other than that, torch-sdp is slightly faster on Ampere and Turing, and likely also on Ada, but I haven't tested that one.

ebr avatar Aug 07 '24 14:08 ebr

As I said, it costs nothing to keep supporting it in the code, but it looked like most people thought it should be removed, so I removed it. I can easily bring it back. Should we select the attention type based on CUDA compute capability, so that devices with compute capability 7.5 and above default to torch-sdp and older ones default to xformers (if available)?

StAlKeR7779 avatar Aug 07 '24 14:08 StAlKeR7779

That seems reasonable, with a configurable override if user wants to force one.

It does look like we should add it back.

hipsterusername avatar Aug 07 '24 15:08 hipsterusername

The PyTorch blog says Flash Attention is supported from sm80 compute capability onwards: https://pytorch.org/blog/accelerated-pytorch-2/, so perhaps we should default to xformers for anything lower than that. (As @hipsterusername said, with a configurable override).
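The proposed default with a configurable override might look like the sketch below. The function name and parameters are hypothetical. The capability tuple is in the (major, minor) form returned by torch.cuda.get_device_capability(), and the threshold is left as a parameter because the thread suggests both 7.5 and 8.0 (sm80). Non-CUDA devices such as MPS are deliberately out of scope here.

```python
from typing import Optional


def default_cuda_attention_type(
    capability: tuple[int, int],
    xformers_available: bool,
    override: Optional[str] = None,
    min_sdp_capability: tuple[int, int] = (8, 0),
) -> str:
    """Pick a default attention_type for a CUDA device (hypothetical helper).

    capability is (major, minor), e.g. as returned by torch.cuda.get_device_capability().
    """
    # A user-configured value always wins over the automatic choice.
    if override is not None and override != "auto":
        return override
    # Tuples compare lexicographically, so (7, 5) < (8, 0).
    if capability < min_sdp_capability and xformers_available:
        return "xformers"
    return "torch-sdp"
```

With the default threshold, a Pascal card such as the P40 (6.1) would get xformers when it is installed, while Ampere and newer would get torch-sdp. Passing min_sdp_capability=(7, 5) instead matches the 7.5 cut-off suggested earlier in the thread.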

ebr avatar Aug 07 '24 15:08 ebr

Not for this PR, but I did some performance testing and we'll probably want to address this at some point:

SDXL:

>>> Time taken to prepare attention processors: 0.10069823265075684s
>>> Time taken to prepare attention processors: 0.07877492904663086s
>>> Time taken to set attention processors: 0.1278061866760254s
>>> Time taken to reset attention processors: 0.13225793838500977s

Code used to measure:

    import time
    from contextlib import contextmanager

    from diffusers import UNet2DConditionModel


    # Method on the class that defines _prepare_attention_processors().
    @contextmanager
    def apply_custom_attention(self, unet: UNet2DConditionModel):
        """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers."""
        start = time.time()
        attn_procs = self._prepare_attention_processors(unet)
        time_1 = time.time()
        print(f">>> Time taken to prepare attention processors: {time_1 - start}s")
        orig_attn_processors = unet.attn_processors
        time_2 = time.time()
        # This interval times reading unet.attn_processors (its label repeats the one above).
        print(f">>> Time taken to prepare attention processors: {time_2 - time_1}s")

        try:
            # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from
            # the passed dict. So, if you wanted to keep the dict for future use, you'd have to make a
            # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`.
            unet.set_attn_processor(attn_procs)
            time_3 = time.time()
            print(f">>> Time taken to set attention processors: {time_3 - time_2}s")
            yield None
        finally:
            time_4 = time.time()
            unet.set_attn_processor(orig_attn_processors)
            time_5 = time.time()
            print(f">>> Time taken to reset attention processors: {time_5 - time_4}s")
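The manual time_1 ... time_5 bookkeeping in the measurement code is easy to mislabel (the second interval reuses the "prepare" label). A small timing context manager is one way to factor it out; the name timed and its signature are hypothetical, and time.perf_counter() is used because it is intended for measuring short intervals.

```python
import time
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def timed(label: str, timings: dict[str, float], verbose: bool = True) -> Iterator[None]:
    """Record the wall-clock duration of the with-block in timings[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block raises, so failed stages still show up.
        elapsed = time.perf_counter() - start
        timings[label] = elapsed
        if verbose:
            print(f">>> {label}: {elapsed:.4f}s")
```

Inside apply_custom_attention, each stage would then be wrapped individually, e.g. `with timed("set attention processors", timings): unet.set_attn_processor(attn_procs)`, so every label sits next to the code it measures.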

RyanJDick avatar Aug 08 '24 14:08 RyanJDick

It looks like there was a significant re-write of the attention logic after the latest round of review and testing on this PR. @StAlKeR7779 can you shed some light on the benefits / motivation for that latest re-write?

Given the amount of testing and discussion that went into this branch before the re-write, I'm wondering if we should revert those latest changes, merge this PR, and then open a new PR with the changes. Let me know what you think.

RyanJDick avatar Sep 16 '24 15:09 RyanJDick