
Refactor cross attention and allow mechanism to tweak cross attention function


We are thinking about how to best support methods that tweak the cross attention computation, such as hypernetworks (where linear layers mapping k -> k' and v -> v' are trained), prompt-to-prompt, and other customized cross attention mechanisms.

Supporting such features poses a challenge in that we need to allow the user to "hack" into the cross attention module which is "buried" inside the unet.

We make two assumptions here:

  1. The cross attention layer is always the connection between the conditioning (text encoder or image encoder) and the unet, so we can limit the scope to allowing users to hack only this connection. Therefore, we can expect the scope of other cross attention methods to stay within an API that gets (hidden_states, context_embeddings) and returns (hidden_states) again:
cross_attention: (hidden_states, context_embeddings) -> hidden_states
  2. We also assume that such hacking at inference time only makes sense if all previously trained weights stay the same and if all previously trained weights are used. This means that we don't allow overwriting existing weights and instead just give the user access to the existing weights:
cross_attention_fn: (hidden_states, query_weight, key_weight, value_weight, context_embeddings) -> hidden_states

Therefore, a nice API that is both somewhat clean and flexible is to just let the user write "CrossAttentionProcessor" classes that are weightless by default and take (query_weight, key_weight, value_weight) as inputs, as you can see in this PR. I also took this new design as an opportunity to refactor the cross attention layer a bit and to make xformers, sliced attention, and normal attention different "processor" classes.

Now, let's imagine one would like to support prompt-to-prompt. In this case one should be able to do the following:

Note this is pseudo code:

import torch

from diffusers import CrossAttentionProcMixin

unet = ...  # load the unet

class P2PCrossAttentionProc(CrossAttentionProcMixin):

    def __init__(self, head_size, upcast_attention, attn_maps_reweight):
        super().__init__(head_size=head_size, upcast_attention=upcast_attention)
        self.attn_maps_reweight = attn_maps_reweight

    def __call__(self, hidden_states, query_proj, key_proj, value_proj, encoder_hidden_states, modified_text_embeddings):
        batch_size, sequence_length, _ = hidden_states.shape

        # project and reshape the query once; it is shared by both attention maps
        query = query_proj(hidden_states)
        query = self.head_to_batch_dim(query, self.head_size)

        original_text_embeddings = encoder_hidden_states
        attention_probs = []
        for context in [original_text_embeddings, modified_text_embeddings]:
            key = key_proj(context)
            value = value_proj(context)

            key = self.head_to_batch_dim(key, self.head_size)
            value = self.head_to_batch_dim(value, self.head_size)

            attention_probs.append(self.get_attention_scores(query, key))

        # blend the attention maps computed from the original and the modified prompt
        merged_probs = (1 - self.attn_maps_reweight) * attention_probs[0] + self.attn_maps_reweight * attention_probs[1]
        hidden_states = torch.bmm(merged_probs, value)
        hidden_states = self.batch_to_head_dim(hidden_states)
        return hidden_states

proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)


unet.set_cross_attention_processor(proc)

unet(sample=sample, t=t, encoder_hidden_states=orig_text_embeddings, ...)

I think this design is flexible enough to allow for more complicated use cases and can also be used for training hypernetworks!

Another important point here is mutability. If I pass the same class instance to multiple layers, the class is mutable, which might be desired. However, it might also be desirable to pass immutable or different processor classes to different cross attention layers. In this case, we could simply allow setting a list of processor objects instead of just one object.
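For illustration, the two options could look roughly like this (the setter name follows the pseudo code above; the list form is just a sketch, not a finalized API):

# one shared (mutable) processor instance used by every cross attention layer
shared_proc = P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, 0.6)
unet.set_cross_attention_processor(shared_proc)

# ...or one independent processor per cross attention layer
per_layer_procs = [
    P2PCrossAttentionProc(unet.config.head_size, unet.config.upcast_attention, weight)
    for weight in (0.2, 0.4, 0.6)  # hypothetical per-layer reweighting factors
]
unet.set_cross_attention_processor(per_layer_procs)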

patrickvonplaten avatar Dec 09 '22 16:12 patrickvonplaten

Very common processor classes such as prompt-to-prompt could also be added natively to diffusers in a new pipeline that defines a whole attention processor class. Prompt-to-prompt would be a good example.

patrickvonplaten avatar Dec 09 '22 18:12 patrickvonplaten

Looks promising! I think separating out xformers and sliced attention like this helps readers as well as any code-tracers, as it no longer has those if-branches in the forward method.

I'll leave it to damian to comment on whether the API is sufficient for the types of applications he has in mind. But the thing I'd be wondering about is: What if I want to do my application-specific thing and also take advantage of the library's efficient implementations of attention?

Maybe that question is too vague without a concrete example, because I know not all "application-specific things" are necessarily going to be compatible with every attention implementation.

An illustrative example might be making visualizations of the attention maps, as we see in the Prompt-to-Prompt paper (the squirrel attention-map figure) or in StructureDiffusion (the attention-map figure).

The visualization is not something that wants to control attention. It still wants to use the most efficient implementation available.

Disclosure: attention map visualizations are literally a feature InvokeAI wants to implement (as does https://github.com/JoaoLages/diffusers-interpret), but I think the interpretability it provides is of general interest. I'm not just trying to trick you into writing InvokeAI-specific features as examples. 😉

keturn avatar Dec 09 '22 20:12 keturn

Therefore, a nice API that is both somewhat clean and flexible is to just let the user write "CrossAttentionProcessor" classes that are weightless by default and take (query_weight, key_weight, value_weight) as inputs, as you can see in this PR.

This is good! However, as keturn says:

The visualization is not something that wants to control attention. It still wants to use the most efficient implementation available.

It would indeed be best if the CustomCrossAttentionProcessor could optionally call parts of the default CrossAttentionProcessor's functionality -- which would be exposed in a modular way -- to get the benefits of whatever clever optimisations HuggingFace put in there. Of course we can always just copy/paste the existing code, but this leaves us downstream users with the burden of maintaining compatibility as Diffusers code changes upstream.
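For example, something along these lines would already be enough for the visualization case: subclass the default processor (the class name here is hypothetical) and only hook into one of its exposed building blocks. This is a sketch, not a concrete proposal:

class VisualizingCrossAttentionProc(DefaultCrossAttentionProc):  # hypothetical default processor
    def __init__(self, *args, attention_map_sink=None, **kwargs):
        super().__init__(*args, **kwargs)
        # e.g. a plain list that collects attention maps for later visualization
        self.attention_map_sink = attention_map_sink

    def get_attention_scores(self, query, key, *args, **kwargs):
        # reuse the library's (possibly optimized) implementation, only observe the result
        attention_probs = super().get_attention_scores(query, key, *args, **kwargs)
        if self.attention_map_sink is not None:
            self.attention_map_sink.append(attention_probs.detach().cpu())
        return attention_probs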

damian0815 avatar Dec 10 '22 10:12 damian0815

if all previously trained weights stay the same and if all previously trained weights are used

@patrickvonplaten one thing we are looking at using is LoRA, which trains "the residual" of the weights to apparently produce dreambooth-quality training in 3-4MB of shipped data:

... not all of the parameters need tuning: they found that often, Q, K, V, O (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.

from https://github.com/cloneofsimo/lora/blob/master/scripts/run_inference.ipynb, which monkey-patches the default Diffusers stable diffusion pipeline pipe (screenshot omitted).

I can't claim to understand the math well enough to know exactly what that means; just thought I should flag it as one use case we are looking into.

damian0815 avatar Dec 10 '22 12:12 damian0815

@anton-l @pcuenca @williamberman could you give this a review?

patrickvonplaten avatar Dec 11 '22 16:12 patrickvonplaten

Looks good! I think it's very compelling to model the existing xFormers and sliced attention optimizations as just instances of the new "cross-attention processor" class. I also think it would still be useful for some downstream applications to take advantage of them when possible. I've been thinking about how to achieve that, but unfortunately I currently don't see a clear path forward that is not overly complicated.

I'd start with this proposed solution to support a whole new type of applications, and then maybe we can later expand to "observer-only" callbacks that can be added independently of processors and can coexist with them. This way we could also support visualizers or loggers as an additional family of components.

In terms of the API, I like it in general and it looks clear to me. I wonder if it'd make sense to do the following instead of passing all the arguments to __call__:

  • Keep batch_to_head_dim, head_to_batch_dim in CrossAttention.
  • Implement get_attention_scores in CrossAttention too.
  • Pass the CrossAttention instance to __call__ as something like xattn, so the processor can use it like:
key = xattn.to_k(original_text_embeddings)
value = xattn.to_v(original_text_embeddings)
# etc
attention_probs.append(xattn.get_attention_scores(query, key))

This requires minimal changes to the current class, users don't have to look into another file for those implementations, and processors have information about other details they might potentially need from CrossAttention. The mixin is just doing stuff the CrossAttention module used to do.

pcuenca avatar Dec 11 '22 21:12 pcuenca

Love this

The mutability point is interesting and does seem like it might run into some trouble down the road, but given that none of the existing use cases rely on mutable processors, I would be OK with just moving forward with this as is.

williamberman avatar Dec 12 '22 18:12 williamberman

Hi, this change would be super useful!

I have a question though, is it possible to tweak the attention map like the Prompt-to-Prompt paper suggests, as well as using memory efficient attention at the same time?

From what I understand, prompt-to-prompt works by multiplying the attention map with the modified token weights, but memory efficient attention changes the formula and doesn't explicitly calculate the attention map in the same way. Am I understanding it wrong? Is it possible to do both things together?

DavidePaglieri avatar Dec 13 '22 13:12 DavidePaglieri

if all previously trained weights stay the same and if all previously trained weights are used

@patrickvonplaten one thing we are looking at using is LoRA, which trains "the residual" of the weights to apparently produce dreambooth-quality training in 3-4MB of shipped data:

... not all of the parameters need tuning: they found that often, Q, K, V, O (i.e., attention layer) of the transformer model is enough to tune. (This is also the reason why the end result is so small). This repo will follow the same idea.

from https://github.com/cloneofsimo/lora/blob/master/scripts/run_inference.ipynb, which monkey-patches the default Diffusers stable diffusion pipeline pipe (screenshot omitted).

I can't claim to understand the math well enough to know exactly what that means; just thought I should flag it as one use case we are looking into.

That's a very good point! LoRA looks indeed very promising, and it seems to adapt all linear layers of the CrossAttention module (see here: https://github.com/cloneofsimo/lora/blob/26787a09bff4ebcb08f0ad4e848b67bce4389a7a/lora_diffusion/lora.py#L177), so maybe we should allow plugging in the whole CrossAttention class right away?
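For reference, the core of the LoRA idea is small enough to sketch here (this is just an illustration of the low-rank residual trick, not the cloneofsimo/lora code):

import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base_linear: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base_linear                    # frozen, pre-trained projection (e.g. to_q/to_k/to_v/to_out)
        self.base.requires_grad_(False)
        self.down = nn.Linear(base_linear.in_features, rank, bias=False)   # A
        self.up = nn.Linear(rank, base_linear.out_features, bias=False)    # B
        nn.init.zeros_(self.up.weight)             # the residual starts as an exact no-op
        self.scale = scale

    def forward(self, x):
        # original output plus a scaled low-rank residual (the only trained part)
        return self.base(x) + self.scale * self.up(self.down(x))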

patrickvonplaten avatar Dec 14 '22 16:12 patrickvonplaten

Hi, this change would be super useful!

I have a question though, is it possible to tweak the attention map like the Prompt-to-Prompt paper suggests, as well as using memory efficient attention at the same time?

From what I understand, prompt-to-prompt works by multiplying the attention map with the modified token weights, but memory efficient attention changes the formula and doesn't explicitly calculate the attention map in the same way. Am I understanding it wrong? Is it possible to do both things together?

The problem with memory efficient attention here is that the whole attention operation is done by xformers' highly optimized attention function, which doesn't expose internals such as the QK^T attention maps.

patrickvonplaten avatar Dec 14 '22 16:12 patrickvonplaten


I just learned that StructureDiffusion is no longer anonymous; see #878 for more on this cross-attention use case.

keturn avatar Dec 17 '22 10:12 keturn

Checked UnCLIP slow tests as well as SD1 and SD2 slow tests.

patrickvonplaten avatar Dec 20 '22 16:12 patrickvonplaten

@patrickvonplaten awesome! thanks! will use it when it is merged then!

kashif avatar Dec 20 '22 16:12 kashif

I'm pretty late to this discussion, but there is an implementation of prompt-to-prompt which supports using xformers: https://github.com/cccntu/efficient-prompt-to-prompt. tl;dr: it can be achieved by passing both prompts simultaneously as key/value in CrossAttention. Previously this required applying a patch to diffusers, but with the new design it can be replaced by writing a custom processor class, which is better.

Another +1 to @patrickvonplaten for good PR and design

bonlime avatar Dec 23 '22 14:12 bonlime

Thanks a lot for sharing this @bonlime , cc @kashif

patil-suraj avatar Dec 23 '22 15:12 patil-suraj

Thanks @bonlime! Yes, I believe that using this technique we can make a simple processor which takes a tuple for the encoder_hidden_states and passes each element to the appropriate projection... let me try that out.
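Roughly what I have in mind (just a sketch following the processor interface discussed in this thread; which prompt feeds the keys and which feeds the values is a detail of the specific method):

import torch

class KeyValueSwapCrossAttnProcessor:
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        # encoder_hidden_states is a tuple: one prompt feeds the keys, the other the values
        key_embeds, value_embeds = encoder_hidden_states

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        key = attn.head_to_batch_dim(attn.to_k(key_embeds))
        value = attn.head_to_batch_dim(attn.to_v(value_embeds))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))

        # output projection and dropout
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states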

kashif avatar Dec 23 '22 15:12 kashif

Just noticed one small backwards-incompatible API change here, I think: before this change, enabling xformers would override slicing, so you could safely do both in either order and xformers would always be used.

Now, if you enable xformers and then slicing, slicing will take precedence, so the order of enabling becomes important.

hafriedlander avatar Dec 26 '22 00:12 hafriedlander

Good catch @hafriedlander! Yeah, slicing was ignored if xformers was enabled before this change, but I think silently ignoring it is not good. Maybe we could log/warn when the attention method is being overridden. That said, IMO the current API is good: if you enable sliced attention, then sliced attention will be used rather than being silently ignored. cc @patrickvonplaten @pcuenca @anton-l
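Something as simple as this could do it (names are illustrative, not the actual diffusers internals):

import warnings

def _set_processor_with_warning(module, new_processor):
    old_processor = getattr(module, "processor", None)
    if old_processor is not None and type(old_processor) is not type(new_processor):
        warnings.warn(
            f"Replacing attention processor {type(old_processor).__name__} with "
            f"{type(new_processor).__name__}; the previously enabled attention method will no longer be used."
        )
    module.processor = new_processor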

patil-suraj avatar Dec 26 '22 11:12 patil-suraj

I agree with @patil-suraj. We could maybe clarify the behaviour in the docstrings?

pcuenca avatar Dec 26 '22 12:12 pcuenca

Another issue with regard to sliced attention: I believe it is currently not possible to set a different sliced attention processor in a pipeline, since that requires the slice_size, and the set_attention_slice helper that sets the slice size defaults to installing the standard processors...

kashif avatar Dec 26 '22 12:12 kashif

Hi all, it's great that this PR landed in 0.12, so that we can experiment with various attention modification techniques. What exactly is the API for this feature, though? The discussion started with a proposal, but I am not sure what the final API looks like.

Here is an example of implementing Attend-and-Excite as an attention processor, but some points are still a little obscure to me, for instance:

  • what exactly is the API for these controllers? How do they get registered? I see that @evinpinar defines their own register_attention_control function
  • the example here uses model.set_attn_processor instead
  • what about the attention store that @evinpinar is using? Again, it seems that they are implementing their own; is it something provided by diffusers?

In short, even a few lines on how these classes are expected to be defined/used would go a long way!

andreaferretti avatar Feb 07 '23 09:02 andreaferretti

Hi @andreaferretti,

Here is an example of implementing Attend-and-Excite as an attention processor, but some points are still a little obscure to me, for instance:

* what exactly is the API for these controllers? How do they get registered? I see that @evinpinar defines their own `register_attention_control` function

* the example [here](https://github.com/huggingface/diffusers/pull/1639/files/9d5e5ca9b18b55d864d931a4b6199c99065c5cd7#diff-44dfc935910f5504cfa2bb02e5a4313cfdc061a6131b47f93b31b7d41422fd25) uses `model.set_attn_processor` instead

In the sample I've provided for the Attend-and-Excite paper, we need to set a processor only on specific CrossAttention layers. I also use model.set_attn_processor within register_attention_control. See here.

* what about the attention store that @evinpinar is using? Again, it seems that they are implementing their own; is it something provided by diffusers?

The attention store is an accumulator of the probabilities at specific cross-attention layers with specific resolutions after a forward pass through the UNet. The attention probabilities then get optimized, depending on each other's values. As far as I understand, it is not possible to do this optimization without such an accumulator/optimization control. The diffusers library enables tweaking the functionality of attention and accessing the values, yet I'm not sure we can achieve the optimization without this additional store API.
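In case it helps, here is a stripped-down sketch of the pattern (not the Attend-and-Excite code, just an illustration of feeding a store from a processor using the interface shown elsewhere in this thread):

import torch

class AttentionStore:
    def __init__(self):
        self.maps = []

    def reset(self):
        self.maps = []

class StoringCrossAttnProcessor:
    def __init__(self, store: AttentionStore):
        self.store = store

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))
        context = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # accumulate the probabilities so they can be inspected or optimized after the UNet forward pass
        self.store.maps.append(attention_probs)

        hidden_states = attn.batch_to_head_dim(torch.bmm(attention_probs, value))
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states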

In short, even a few lines on how these classes are expected to be defined/used would go a long way!

I agree, it would be very useful!

evinpinar avatar Feb 07 '23 09:02 evinpinar

@andreaferretti I've successfully used the AttnProcessor API in InvokeAI; see for example SlicedSwapCrossAttnProcesser, which gets used like this.

damian0815 avatar Feb 07 '23 12:02 damian0815

Hi @evinpinar, thank you for your prompt response!

So, one thing I gather from your example is that set_attn_processor accepts a dictionary mapping layer names to processors, and will use the corresponding processor on that specific layer, right? The other example here just calls

processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)

which I can only assume will call the same processor on every layer. Are there any other overloads of set_attn_processor? Anyway, these two should be enough!
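So the two call forms, as far as I can tell, would be (the layer name below is only illustrative):

proc = AttnEasyProc(5.0)

# same processor instance on every attention layer
model.set_attn_processor(proc)

# different processors per layer, keyed by the layer's qualified name
model.set_attn_processor({
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor": proc,
    # ...one entry per attention layer
})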

Another peculiarity of the API that I gather from the above example is that one can pass extra kwargs as a dictionary, like

model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

and said kwargs will apparently be passed when calling the processor; in fact, the signature of __call__ there is

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):

The last piece of the puzzle I am not sure about is to what extent one needs to reproduce the "normal" attention mechanism in processors. That is, even the simple processor in the example has the usual attention computation

def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
      batch_size, sequence_length, _ = hidden_states.shape
      attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

      query = attn.to_q(hidden_states)

      encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
      key = attn.to_k(encoder_hidden_states)
      value = attn.to_v(encoder_hidden_states)

      query = attn.head_to_batch_dim(query)
      key = attn.head_to_batch_dim(key)
      value = attn.head_to_batch_dim(value)

      attention_probs = attn.get_attention_scores(query, key, attention_mask)
      hidden_states = torch.bmm(attention_probs, value)
      hidden_states = attn.batch_to_head_dim(hidden_states)

      # linear proj
      hidden_states = attn.to_out[0](hidden_states)
      # dropout
      hidden_states = attn.to_out[1](hidden_states)

      return hidden_states

So I assume that the computation taking place in these processors will replace the default attention computation, instead of, say, augmenting it in some way. In other words, every processor will have to copy this first and then modify the flow to achieve whatever it needs to do, instead of getting the already computed attention maps and just having to possibly modify them.

I think this is the case, but I am just asking if there are gross misunderstandings.

Excuse me if my questions seem naive, but this monkey patching of attention maps is already delicate enough, and without some documentation on the exact API it is hard to get started.

andreaferretti avatar Feb 07 '23 13:02 andreaferretti

Hi @damian0815, thank you, it is very useful to have another example to learn from!!

andreaferretti avatar Feb 07 '23 13:02 andreaferretti

Hey @andreaferretti, those are really good questions. It would be awesome if you could open an issue. We definitely want to improve the documentation for this. cc @patrickvonplaten

patil-suraj avatar Feb 08 '23 08:02 patil-suraj

Hi @andreaferretti, the way I understand and use the processors aligns with your explanation, but I'm sure there can be many exciting ways to use them.

evinpinar avatar Feb 08 '23 09:02 evinpinar