
Refactor `attention.py`

patrickvonplaten opened this issue 2 years ago

attention.py currently has two concurrent attention implementations that essentially do the exact same thing:

  • https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 and
  • https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/cross_attention.py#L30

Both https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/cross_attention.py#L30 and https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 are already used for "simple" attention - e.g. the former for Stable Diffusion and the latter for the simple DDPM UNet.

We should start deprecating https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 very soon as it's not viable to keep two attention mechanisms.

Deprecating this class won't be easy, as it essentially means we have to force people to re-upload their weights: essentially every model checkpoint that made use of https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 will eventually have to have its weights re-uploaded to remain compatible.

I would propose to do this in the following way:

    1. To begin with, when https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 is called, the code will convert the weights on the fly to https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/cross_attention.py#L30 with a very clear deprecation message that explains in detail how one can save & re-upload the weights to remove the deprecation message (a sketch of such a conversion is shown after this list).
    2. Open a mass PR on all checkpoints that make use of https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 (they can be retrieved via the config) to convert the weights to the new format.
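A minimal sketch of what the on-the-fly conversion in step 1 could look like. The parameter names (`query`, `key`, `value`, `proj_attn` on the old block vs. `to_q`, `to_k`, `to_v`, `to_out.0` on CrossAttention) are assumptions based on the two classes and may not match the eventual implementation:

```python
import warnings

# Assumed key mapping from the deprecated AttentionBlock to CrossAttention.
_ATTENTION_BLOCK_TO_CROSS_ATTENTION = {
    "query": "to_q",
    "key": "to_k",
    "value": "to_v",
    "proj_attn": "to_out.0",
}


def convert_attention_block_state_dict(state_dict):
    """Rename deprecated AttentionBlock keys to their CrossAttention equivalents."""
    warnings.warn(
        "AttentionBlock is deprecated; weights are being converted on the fly. "
        "Please save and re-upload the converted checkpoint to silence this warning.",
        FutureWarning,
    )
    converted = {}
    for key, value in state_dict.items():
        for old, new in _ATTENTION_BLOCK_TO_CROSS_ATTENTION.items():
            if f".{old}." in key:
                key = key.replace(f".{old}.", f".{new}.")
                break
        converted[key] = value
    return converted
```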

Happy to help with this PR. @williamberman are you maybe interested in taking this one?

patrickvonplaten · Jan 01 '23

@patrickvonplaten rather than having everybody save & re-upload their weights: can diffusers intercept the weights during model load and map them to different parameter names?

Apple uses PyTorch's _register_load_state_dict_pre_hook() idiom to intercept the weights while the state dict is being loaded, transform them, and redirect them to be held in different Parameters (a minimal sketch of the idiom follows the link):
https://github.com/apple/ml-ane-transformers/blob/da64000fa56cc85b0859bc17cb16a3d753b8304a/ane_transformers/huggingface/distilbert.py#L241
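Roughly, the idiom looks like this (a toy sketch with made-up module and key names, not the Apple code):

```python
import torch.nn as nn


class RenamedLinear(nn.Module):
    """Toy module that stores its weight under a new parameter name but can
    still load checkpoints saved with the old name, via a pre-load hook."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.new_proj = nn.Linear(in_features, out_features)
        # Private PyTorch API (also used by Apple's ml-ane-transformers):
        # the hook runs before load_state_dict() copies tensors into parameters.
        self._register_load_state_dict_pre_hook(self._remap_old_keys)

    def _remap_old_keys(self, state_dict, prefix, local_metadata, strict,
                        missing_keys, unexpected_keys, error_msgs):
        # Rename e.g. "old_proj.weight" -> "new_proj.weight" in place.
        for suffix in ("weight", "bias"):
            old_key = f"{prefix}old_proj.{suffix}"
            if old_key in state_dict:
                state_dict[f"{prefix}new_proj.{suffix}"] = state_dict.pop(old_key)

    def forward(self, x):
        return self.new_proj(x)
```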

However, something about HF's technique for model loading breaks this idiom: my model-loading hooks never get invoked. They work in a CompVis repository, but not inside HF diffusers code. I think something about using importlib to load a .bin skips it. It'd be really good if you could fix that; it's the number one thing that made it difficult for me to optimize the Diffusers UNet for the Neural Engine.

In the end, this is the technique I had to resort to in order to replace every AttentionBlock with CrossAttention (after model loading):
https://github.com/Birch-san/diffusers-play/commit/bf9b13e6e6861af0d300584a5c7c0a9ec3d79a28
You may find it a useful reference for how to map between them (a simplified sketch of the module-swap approach follows).
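A simplified sketch of that kind of post-load swap. The AttentionBlock attribute names (`channels`, `num_heads`) and the CrossAttention constructor arguments are written from memory for the diffusers version current at the time and may differ; the actual commit above also copies the weights over:

```python
import torch.nn as nn

# Import paths as they existed around diffusers 0.11/0.12.
from diffusers.models.attention import AttentionBlock
from diffusers.models.cross_attention import CrossAttention


def replace_attention_blocks(module: nn.Module) -> None:
    """Recursively swap deprecated AttentionBlock modules for CrossAttention ones."""
    for name, child in module.named_children():
        if isinstance(child, AttentionBlock):
            new_attn = CrossAttention(
                query_dim=child.channels,
                heads=child.num_heads,
                dim_head=child.channels // child.num_heads,
            )
            # A real replacement must also copy the q/k/v, output-projection and
            # group_norm weights from `child` into `new_attn` (see the commit above).
            setattr(module, name, new_attn)
        else:
            replace_attention_blocks(child)
```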

Birch-san · Jan 03 '23

@Birch-san Thank you for the added context, super helpful! I don't have much to add right now. When I start working on the refactor, I'll think about it more and we can discuss :)

williamberman · Jan 04 '23

It seems the two classes might have some slight differences. I noticed group_norm is missing from a few of the processor implementations for the CrossAttention class, which can be configured with or without group_norm.

CrossAttnProcessor doesn't use group_norm; SlicedAttnAddedKVProcessor and CrossAttnAddedKVProcessor use group_norm (without checking whether it is actually None, which CrossAttention allows).

Whereas AttentionBlock in attention.py always uses group_norm.

Lime-Cakes · Jan 12 '23

tl;dr

  1. Does the API design of the attention processors prohibit anything from existing in CrossAttention.__call__, or should the entirety of the method exist in the processor so that all of the attention mechanism is hackable? An example being residual connections.
  2. What ad hoc configuration should we allow for attention processors? What about when it results in diverging defaults? I.e. making residual connections configurable now would have different defaults for different processors.
  3. While porting AttentionBlock to CrossAttention, should we make CrossAttnProcessor configurable, or make a new processor that is mostly the same as CrossAttnProcessor but with a residual connection? (There are potentially other changes that might need to be made; I haven't looked through all of it yet.)

longer message:

QQ re API design in the attention processor: if we were to configure whether or not there is a residual connection, would this, in an ideal world, occur in the processor class or in CrossAttention? Currently we just pass everything from __call__ into the processor. Some processors have a residual connection, some do not. Because the attention block has the same dims in as dims out, the residual connection would be the same regardless of what happens in the processor, and it could be done within CrossAttention.__call__.

However, doing the residual connection within CrossAttention.__call__ means that not the entirety of the attention application is hackable.

Context is that AttentionBlock has a residual connection and CrossAttnProcessor does not, so I'm leaning towards adding a config to CrossAttnProcessor's constructor for it to perform a residual connection (a hypothetical sketch of this option is shown below). However, if we were to make the residual connection configurable in other processors that already have one, they would have different default values.

Alternatively, we could make a separate attention processor for the currently deprecated AttentionBlock that would be mostly copy-paste from the existing CrossAttnProcessor.
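One possible shape of the configurable-residual option, loosely following the structure of the existing processors. This is a hypothetical sketch (the class name and the exact helper methods on `attn` are assumptions), not the actual CrossAttnProcessor code:

```python
import torch


class AttnProcessorWithOptionalResidual:
    """Hypothetical processor with a constructor-configurable residual connection."""

    def __init__(self, residual_connection: bool = False):
        self.residual_connection = residual_connection

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)  # output projection
        hidden_states = attn.to_out[1](hidden_states)  # dropout

        # The point under discussion: the residual lives in the processor and is
        # only applied when configured, so defaults can differ per processor.
        if self.residual_connection:
            hidden_states = hidden_states + residual

        return hidden_states
```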

williamberman · Jan 26 '23

Follow-up: residual connections would stay in the processor regardless, because the residual connection isn't guaranteed to be the last step in the method. I.e. in AttentionBlock, we rescale the output after the residual connection is applied.
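For concreteness, a tiny sketch of the ordering in question (simplified; `rescale_output_factor` is the AttentionBlock argument of that name):

```python
def attention_block_epilogue(attn_output, residual, rescale_output_factor=1.0):
    # In AttentionBlock the residual is added *before* rescaling, so a residual
    # applied generically after the processor returns would compute something different.
    return (attn_output + residual) / rescale_output_factor
```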

IMO, this means that regardless of commonalities, the entirety of the attention application should occur in the processor. Anything that we assume to be common to all processors is potentially a point for breakage and might need bad hacks to make work on future attention processors.

However the other questions around configuration w/ defaults still stand.

williamberman · Jan 26 '23

tl;dr from offline convo:

The existing CrossAttnProcessor provides attention over inputs of size (batch_size, seq_len, hidden_size), and the AttentionBlock we're deprecating provides attention over spatial inputs of (batch_size, channels, height, width), so we'll make a separate class called SpatialAttnProcessor.

For now, we'll only add self attention to the new attention processor and we can add in cross attention later. Note that we'll also change the name of CrossAttnProcessor to just AttnProcessor so the standard naming will be consistent regardless of the type of attention applied (these are internal/private classes, so changing the names should be acceptable).

We did not discuss what will happen in the future if we have to add configuration to different attention processors and that results in different default configs (i.e. the residual connection example earlier). Let's assume it's OK not to discuss this for now, especially as these are private classes and we'll have more flexibility if we have to make changes to them.
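A rough sketch of what the reshaping in such a SpatialAttnProcessor could look like. This is hypothetical (self-attention only, helper methods on `attn` assumed to match the existing processors), not the design that was agreed on:

```python
import torch


class SpatialAttnProcessor:
    """Hypothetical: self-attention over (batch, channels, height, width) inputs,
    done by flattening the spatial dimensions into a sequence dimension."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        batch, channels, height, width = hidden_states.shape

        # (b, c, h, w) -> (b, h*w, c): the layout the sequence-based processors expect.
        hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

        if attn.group_norm is not None:
            # GroupNorm acts on the channel dimension, so transpose around it.
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(hidden_states))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        # (b, h*w, c) -> (b, c, h, w), then the AttentionBlock-style residual connection.
        hidden_states = hidden_states.transpose(1, 2).view(batch, channels, height, width)
        return hidden_states + residual
```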

williamberman · Jan 27 '23

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions[bot] · Feb 21 '23

This is still very much relevant cc @williamberman

patrickvonplaten · Mar 07 '23

We're getting too many issues / PRs from confused users. Let's try to make this high prio @williamberman

patrickvonplaten · Mar 13 '23

To begin with, let's do the following:

1.) Rename all processors that are called CrossAttention to just Attention.
2.) Rename the file cross_attention.py to attention_processor.py.

Note: We need to keep full backwards compatibility: we import all classes from attention_processor.py into cross_attention.py and raise a deprecation warning whenever someone imports from cross_attention.py. If classes are renamed in attention_processor.py, we should import them as follows:

from .attention_processor import AttentionProcessor as CrossAttentionProcessor
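The backwards-compatibility shim could look something like this. A sketch only: the class names follow the renaming proposed above, and a plain warnings.warn is used here, whereas the actual PR may use diffusers' own deprecation helper:

```python
# cross_attention.py (sketch): keep the old import path working while warning users.
import warnings

from .attention_processor import Attention as CrossAttention  # noqa: F401
from .attention_processor import AttnProcessor as CrossAttnProcessor  # noqa: F401

warnings.warn(
    "Importing from diffusers.models.cross_attention is deprecated; "
    "please import from diffusers.models.attention_processor instead.",
    FutureWarning,
)
```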

patrickvonplaten · Mar 13 '23

Once that's done, let's continue by fully removing the old AttentionBlock class.

patrickvonplaten · Mar 13 '23

Would any change affect old models (e.g. renaming state dict keys)? It seems like changing/removing AttentionBlock wouldn't change the PyTorch state dict, as key names aren't based on class names. I'm unsure whether other formats such as Flax and safetensors would require changes, though.
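To illustrate the point about key names (a quick generic check, not diffusers-specific; note that keys do change if the attribute names inside the module change):

```python
import torch.nn as nn


class OldBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)


class RenamedBlock(OldBlock):
    """Same attributes, different class name."""


# State dict keys come from attribute paths, not class names, so these match.
assert list(OldBlock().state_dict()) == list(RenamedBlock().state_dict())
```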

Lime-Cakes · Mar 13 '23

Start refactor here: https://github.com/huggingface/diffusers/pull/2691#issue-1625959919

patrickvonplaten · Mar 15 '23

@patrickvonplaten Is the refactoring done? I'm using a code base built on diffusers 0.11; if so, I can start reformatting my code.

haofanwang · Mar 21 '23

#2697 will be merged very soon!

patrickvonplaten · Mar 21 '23

Done here: https://github.com/huggingface/diffusers/pull/3387

williamberman · May 31 '23