Refactor `attention.py`
`attention.py` currently has two concurrent attention implementations that essentially do the exact same thing:
- https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 and
- https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/cross_attention.py#L30
Both https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/cross_attention.py#L30 and https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 are already used for "simple" attention - e.g. the former for Stable Diffusion and the latter for the simple DDPM UNet.
We should start deprecating https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 very soon as it's not viable to keep two attention mechanisms.
Deprecating this class won't be easy, as it essentially means we have to force people to re-upload their weights. Essentially every model checkpoint that made use of https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 will eventually have to have its weights re-uploaded to stay compatible.
I would propose to do this in the following way:
- To begin with, when https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 is called, the code will convert the weights on the fly to https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/cross_attention.py#L30 with a very clear deprecation message that explains in detail how one can save & re-upload the weights to remove the deprecation message (a rough sketch of this conversion follows after this list).
- Open a mass PR on all checkpoints that make use of https://github.com/huggingface/diffusers/blob/62608a9102e423ad0fe79f12a8ceb1710d2027b2/src/diffusers/models/attention.py#L256 (can be retrieved via the config) to convert the weights to the new format.
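A minimal sketch of what the on-the-fly conversion from the first bullet could look like, assuming the old `AttentionBlock` parameter names (query/key/value/proj_attn) and the new `CrossAttention` names (to_q/to_k/to_v/to_out.0); the helper itself is hypothetical, not existing diffusers code:

```python
import warnings

# Assumed mapping between the legacy AttentionBlock attributes and the
# CrossAttention attributes; verify against the actual classes before use.
LEGACY_TO_NEW = {
    "query.weight": "to_q.weight",
    "query.bias": "to_q.bias",
    "key.weight": "to_k.weight",
    "key.bias": "to_k.bias",
    "value.weight": "to_v.weight",
    "value.bias": "to_v.bias",
    "proj_attn.weight": "to_out.0.weight",
    "proj_attn.bias": "to_out.0.bias",
    # group_norm keeps its name in both layouts.
}

def convert_attention_block_state_dict(state_dict: dict) -> dict:
    """Return a copy of `state_dict` with legacy AttentionBlock keys renamed."""
    warnings.warn(
        "AttentionBlock is deprecated; weights were converted on the fly. "
        "Please save and re-upload the converted checkpoint to silence this warning.",
        DeprecationWarning,
    )
    return {LEGACY_TO_NEW.get(key, key): value for key, value in state_dict.items()}
```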
Happy to help with this PR. @williamberman are you maybe interested in taking this one?
@patrickvonplaten rather than having everybody save & re-upload their weights: can diffusers intercept the weights during model load and map them to different parameter names?
Apple uses PyTorch's `_register_load_state_dict_pre_hook()` idiom to intercept the weights while the state dict is being loaded, transform them, and redirect them to be held in different Parameters:
https://github.com/apple/ml-ane-transformers/blob/da64000fa56cc85b0859bc17cb16a3d753b8304a/ane_transformers/huggingface/distilbert.py#L241
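For reference, the idiom boils down to something like this minimal sketch; the module, key names, and reshape are made up for illustration and are not the Apple or diffusers code:

```python
import torch
from torch import nn

class RemappedBlock(nn.Module):
    """Stores weights under new parameter names, but can still load checkpoints
    saved under the old names by rewriting keys before they are copied."""

    def __init__(self, dim: int):
        super().__init__()
        # New layout: a 1x1 conv where the old checkpoint had a linear layer.
        self.proj = nn.Conv2d(dim, dim, kernel_size=1)
        # PyTorch calls this hook with the incoming state dict before copying.
        self._register_load_state_dict_pre_hook(self._remap_legacy_keys)

    def _remap_legacy_keys(self, state_dict, prefix, local_metadata, strict,
                           missing_keys, unexpected_keys, error_msgs):
        old_weight = prefix + "linear.weight"
        if old_weight in state_dict:
            # (dim, dim) linear weight -> (dim, dim, 1, 1) conv weight.
            state_dict[prefix + "proj.weight"] = state_dict.pop(old_weight)[:, :, None, None]
        old_bias = prefix + "linear.bias"
        if old_bias in state_dict:
            state_dict[prefix + "proj.bias"] = state_dict.pop(old_bias)

    def forward(self, x):
        return self.proj(x)
```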
however, something about HF's technique for model-loading breaks this idiom. my model loading hooks never get invoked. they work in a CompVis repository, but not inside HF diffusers code. I think something about using importlib to load a .bin
skips it. it'd be really good if you could fix that — it's the number one thing that made it difficult for me to optimize the Diffusers Unet for Neural Engine.
in the end, this is the technique I've had to resort to to replace every AttentionBlock with CrossAttention (after model loading):
https://github.com/Birch-san/diffusers-play/commit/bf9b13e6e6861af0d300584a5c7c0a9ec3d79a28
you may find it a useful reference for how to map between them.
@Birch-san Thank you for the added context, super helpful! I don't have much to add right now. When I start working on the refactor, I'll think about it more and we can discuss :)
It seems the two classes might have some slight differences. I noticed group_norm is missing from a few of the processor implementations for the `CrossAttention` class, which can be used with or without group_norm:
`CrossAttnProcessor` doesn't use group_norm; `SlicedAttnAddedKVProcessor` and `CrossAttnAddedKVProcessor` use group_norm (without checking whether it is actually None, which `CrossAttention` allows).
`AttentionBlock` in attention.py, on the other hand, always uses group_norm.
tl;dr
- Does the API design of the attention processors prohibit anything existing in `CrossAttention.__call__`, or should the entirety of the method live in the processor so that all of the attention mechanism is hackable? Example being residual connections.
- What ad-hoc configuration should we allow for attention processors? What about when it results in diverging defaults, i.e. making residual connections configurable now would mean different defaults for different processors.
- While porting `AttentionBlock` to `CrossAttention`, should we make `CrossAttnProcessor` configurable or make a new processor that is mostly the same as `CrossAttnProcessor` but with a residual connection? (There are potentially other changes that might need to be made; I haven't looked through all of it yet.)
longer message:
QQ re API design in the attention processor: if we were to configure whether or not there is a residual connection, in an ideal world, would this occur in the processor class or in `CrossAttention`? Currently we just pass everything from `__call__` into the processor. Some processors have a residual connection, some do not. Because the attention block has the same dims in as dims out, the residual connection would be the same regardless of what happens in the processor, so the residual connection could be done within `CrossAttention.__call__`.
However, doing the residual connection within `CrossAttention.__call__` means not the entirety of the attention application is hackable.
Context is that `AttentionBlock` has a residual connection and `CrossAttnProcessor` does not, so I'm leaning towards adding a config to `CrossAttnProcessor`'s constructor for it to perform a residual connection. However, if we were to make the residual connection configurable in other processors which already have one, they would have different default values.
Alternatively, we could make a separate attention processor for the currently deprecated `AttentionBlock` that would be mostly copy-paste from the existing `CrossAttnProcessor`.
Follow up: residual connections would stay in the processor regardless, because it isn't guaranteed that the residual connection is the last step in the method. I.e. in `AttentionBlock`, we rescale the output after the residual connection is applied.
IMO, this means that regardless of commonalities, the entirety of the attention application should occur in the processor. Anything that we assume to be common to all processors is potentially a point for breakage and might need bad hacks to make work on future attention processors.
However the other questions around configuration w/ defaults still stand.
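To make the trade-off concrete, here is a rough sketch of a processor that owns the residual connection (and the post-residual rescaling) itself; the `__call__(attn, hidden_states, ...)` shape mirrors the existing processors, but the class, its constructor argument, and the attribute names it touches on `attn` are illustrative assumptions, not the actual diffusers API:

```python
import torch

class ResidualAttnProcessor:
    """Illustrative sketch: the *entire* attention application lives in the
    processor, including the residual add and the post-residual rescaling
    that AttentionBlock performs. Multi-head splitting and attention masks
    are omitted for brevity; attribute names on `attn` are assumptions."""

    def __init__(self, use_residual: bool = True):
        self.use_residual = use_residual

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states

        if getattr(attn, "group_norm", None) is not None:
            # GroupNorm expects (batch, channels, ...), hence the transposes.
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        query, key, value = attn.to_q(hidden_states), attn.to_k(context), attn.to_v(context)

        scores = torch.softmax(query @ key.transpose(-1, -2) / query.shape[-1] ** 0.5, dim=-1)
        hidden_states = attn.to_out[0](scores @ value)

        if self.use_residual:
            hidden_states = hidden_states + residual
        # AttentionBlock rescales *after* the residual is added, which is why
        # the residual cannot simply be hoisted out of the processor.
        return hidden_states / attn.rescale_output_factor
```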
tl;dr from offline convo:
The existing `CrossAttnProcessor` provides attention over inputs of size (batch_size, seq_len, hidden_size) and the `AttentionBlock` we're deprecating provides attention over spatial inputs of (batch_size, channels, height, width), so we'll make a separate class called `SpatialAttnProcessor`.
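In other words, the new processor's main job is the reshape between the spatial layout and the sequence layout that attention operates on; a standalone sketch with plain tensors and callables (not the actual processor interface):

```python
import torch

def spatial_self_attention(hidden_states, to_q, to_k, to_v, to_out,
                           num_heads, rescale_output_factor=1.0):
    """Sketch of self-attention over a (batch, channels, height, width) input.
    Projections are passed in as plain callables here; in diffusers they would
    live on the attention module itself."""
    residual = hidden_states
    batch, channels, height, width = hidden_states.shape

    # (b, c, h, w) -> (b, h*w, c): treat every spatial position as a token.
    hidden_states = hidden_states.view(batch, channels, height * width).transpose(1, 2)

    def split_heads(t):  # (b, h*w, c) -> (b, heads, h*w, c // heads)
        return t.view(batch, -1, num_heads, channels // num_heads).transpose(1, 2)

    query, key, value = (split_heads(p(hidden_states)) for p in (to_q, to_k, to_v))
    scores = torch.softmax(
        query @ key.transpose(-1, -2) / (channels // num_heads) ** 0.5, dim=-1
    )
    out = (scores @ value).transpose(1, 2).reshape(batch, height * width, channels)
    out = to_out(out)

    # (b, h*w, c) -> (b, c, h, w), then the residual add and rescale AttentionBlock does.
    out = out.transpose(1, 2).reshape(batch, channels, height, width)
    return (out + residual) / rescale_output_factor
```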
For now, we'll only add self attention to the new attention processor and we can add in cross attention later. Note that we'll also change the name of CrossAttnProcessor to just AttnProcessor so the standard naming will be consistent regardless of the type of attention applied (these are internal/private classes, so changing the names should be acceptable).
We did not discuss what will happen in the future if we have to add configuration to different attention processors and that results in different default configs (i.e. the residual connection example earlier). Let's assume this is ok to not discuss for now especially as these are private classes and we'll have more flexibility if we have to make changes to them.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This is still very much relevant cc @williamberman
We're getting too many issues / PRs from confused users. Let's try to make this high prio @williamberman
To begin with, let's start by doing the following:
1.) Rename all processors that are called `CrossAttention` to just `Attention`
2.) Rename the file cross_attention.py to attention_processor.py
Note: We need to keep full backwards compatibility: we import all classes from attention_processor.py into cross_attention.py and raise a deprecation warning whenever someone imports from cross_attention.py. If classes are renamed in attention_processor.py we should import them as follows (a sketch of the resulting shim module is below):
`from .attention_processor import AttentionProcessor as CrossAttentionProcessor`
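The backwards-compatibility shim could then look roughly like this; a sketch of what cross_attention.py might contain after the rename, with an illustrative (not exhaustive) list of re-exported classes:

```python
# cross_attention.py -- kept only for backwards compatibility.
# Everything now lives in attention_processor.py; importing from this module
# re-exports the new classes under their old names and warns on import.
import warnings

from .attention_processor import (  # noqa: F401
    Attention as CrossAttention,
    AttnProcessor as CrossAttnProcessor,
)

warnings.warn(
    "Importing from cross_attention.py is deprecated; "
    "please import from attention_processor.py instead.",
    FutureWarning,
)
```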
Once that's done, let's continue by fully removing the old `AttentionBlock` class.
Would any change affect old models (e.g., renaming state dict keys)? It seems like changing/removing `AttentionBlock` wouldn't change the PyTorch state dict, as the key names aren't based on the class name. I'm unsure whether other formats such as Flax and safetensors would require changes though.
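For what it's worth, a tiny made-up example shows why the class rename alone doesn't touch the PyTorch state dict; the keys come from the attribute path, not the class name:

```python
from torch import nn

class OldBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(4, 4)  # the attribute name decides the key

class NewBlock(nn.Module):  # same attribute name, different class name
    def __init__(self):
        super().__init__()
        self.attn = nn.Linear(4, 4)

# Both produce the keys "attn.weight" and "attn.bias", so renaming the class
# alone leaves existing PyTorch checkpoints loadable.
assert OldBlock().state_dict().keys() == NewBlock().state_dict().keys()
```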
Start refactor here: https://github.com/huggingface/diffusers/pull/2691#issue-1625959919
@patrickvonplaten Is the refactoring done? I'm using a code base built on diffusers 0.11; if so, I can start updating my code.
#2697 will be merged very soon!
Done here: https://github.com/huggingface/diffusers/pull/3387