[Community] Implement `prompt-to-prompt` pipelines

Open apolinario opened this issue 1 year ago • 23 comments

Describe the solution you'd like: Now that we have an official way to tweak cross attention (https://github.com/huggingface/diffusers/pull/1639), it would be great to have a pipeline (be it official or community) for prompt-to-prompt and further implementations of the technique (such as EDICT).

Describe alternatives you've considered: @amirhertz's official Prompt-to-Prompt implementation is built on top of diffusers 0.3.0 with its own cross attention manipulation function. @bloc97's community prompt-to-prompt implementation already uses diffusers, but it is pinned to version 0.4.1, also with a cross attention control of its own. @bram-w / Salesforce EDICT, which adds inversion to prompt-to-prompt (allowing you to edit real images), also uses the above as a base, with some modifications for double precision during inversion.

So while alternatives exist, they require users to pin old versions of diffusers and miss out on the latest advancements. Given how useful this technique is, having it as a pipeline within diffusers would be really great. It could also potentially extend the technique to other models (Karlo, IF, etc.).

Additional context: InstructPix2Pix and Imagic have shown how editing real and edited images is a trend. Prompt-to-prompt is a nice tool to have on that belt for practitioners, artists and professionals.

apolinario avatar Jan 26 '23 15:01 apolinario

+100 - just lacking the time at the moment. I wonder whether we should do a community sprint in a week or so trying to add the most important "tweak your text prompts" pipelines.

patrickvonplaten avatar Jan 26 '23 18:01 patrickvonplaten

Actually taking this as an opportunity to turn the feature request into a more precise explanation of how it can be added.

In short, we now have all the necessary tools to add a pipeline like Prompt-2-prompt in a nice & clean way. Prompt-2-prompt is an official pipeline with a paper release and 1k+ stars, so IMO we should put it in src/diffusers/pipelines

What you'll need to do:

  • Add a new pipeline Prompt2PromptPipeline that can be more or less copied from StableDiffusionPipeline
  • As @apolinario said, https://github.com/huggingface/diffusers/pull/1639 now allows one to add a prompt-2-prompt specific attention processor whose __call__ function can also accept new arguments. Those can be passed via cross_attention_kwargs={"swapped_out_embeds": ..., "swapped_in_embeds": ...} to: https://github.com/huggingface/diffusers/blob/7436e30c720547c97b602155b8f5690976efdbc4/src/diffusers/models/unet_2d_condition.py#L403 which will then be passed on to your customized prompt-2-prompt attention processor.
  • The prompt-2-prompt attention processor class should reside inside the src/pipelines/prompt2prompt/pipeline_prompt2prompt.py script and can be loaded directly into the unet via: https://github.com/huggingface/diffusers/blob/7436e30c720547c97b602155b8f5690976efdbc4/src/diffusers/models/unet_2d_condition.py#L297 => The super cool thing about the pipeline is then that it'll be compatible with all SD checkpoints on the Hub.
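The kwargs-forwarding idea in the list above can be sketched without diffusers at all. The toy below is not the diffusers implementation: `ToyAttention` and `SwapEmbedsProcessor` are hypothetical stand-ins, and the `use_swapped` flag is invented for illustration. It only shows how extra keyword arguments given at call time can reach a pluggable processor's `__call__`.

```python
class SwapEmbedsProcessor:
    """Hypothetical processor: picks which text embeddings to attend to."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None,
                 swapped_out_embeds=None, swapped_in_embeds=None, use_swapped=False):
        # Fall back to the regular conditioning when no swap is requested.
        context = encoder_hidden_states
        if use_swapped and swapped_in_embeds is not None:
            context = swapped_in_embeds
        # A real processor would run attention here; the toy just reports its inputs.
        return {"hidden_states": hidden_states, "context": context}


class ToyAttention:
    """Stand-in for an attention layer with a replaceable processor."""

    def __init__(self, processor):
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, **cross_attention_kwargs):
        # Extra kwargs are passed through untouched to the processor.
        return self.processor(self, hidden_states,
                              encoder_hidden_states=encoder_hidden_states,
                              **cross_attention_kwargs)


layer = ToyAttention(SwapEmbedsProcessor())
kwargs = {"swapped_out_embeds": "apples", "swapped_in_embeds": "cookies", "use_swapped": True}
result = layer.forward("latents", encoder_hidden_states="apples", **kwargs)
print(result["context"])
```

Strings stand in for tensors here so the routing is easy to follow; in a real pipeline the same keys would carry embedding tensors.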

Very keen on guiding someone from the community through a PR, but currently don't find the time to do it

patrickvonplaten avatar Jan 26 '23 19:01 patrickvonplaten

You may also reference InvokeAI's update for the diffusers 0.12 attention API: https://github.com/invoke-ai/InvokeAI/pull/2385

A few caveats:

  • we've only implemented the replacement (swap) control so far. Refinement and re-weight are yet to do.
  • the attention code is only a part of what's necessary for a full pipeline; you also need a way to identify which controls to apply where. Those other parts are probably easier to take from the other implementations than InvokeAI.

keturn avatar Jan 26 '23 20:01 keturn

So I had the following attention processors in mind for this variant of the prompt-to-prompt: https://github.com/cccntu/efficient-prompt-to-prompt

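# Imports assumed for diffusers ~0.12, where CrossAttention lived in
# diffusers.models.cross_attention; xformers is only needed by the xFormers variant.
import torch
import xformers.ops
from diffusers.models.cross_attention import CrossAttention

class CrossAttnKVProcessor: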
    def __call__(
        self, attn: CrossAttention, hidden_states, key_hidden_states=None, value_hidden_state=None, attention_mask=None
    ):
        _, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        key_hidden_states = key_hidden_states if key_hidden_states is not None else hidden_states
        value_hidden_state = value_hidden_state if value_hidden_state is not None else hidden_states
        key = attn.to_k(key_hidden_states)
        value = attn.to_v(value_hidden_state)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

class XFormersCrossAttnKVProcessor:
    def __call__(
        self, attn: CrossAttention, hidden_states, key_hidden_states=None, value_hidden_state=None, attention_mask=None
    ):
        _, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        key_hidden_states = key_hidden_states if key_hidden_states is not None else hidden_states
        value_hidden_state = value_hidden_state if value_hidden_state is not None else hidden_states
        key = attn.to_k(key_hidden_states)
        value = attn.to_v(value_hidden_state)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class SlicedAttnKVProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(
        self, attn: CrossAttention, hidden_states, key_hidden_states=None, value_hidden_state=None, attention_mask=None
    ):
        _, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        key_hidden_states = key_hidden_states if key_hidden_states is not None else hidden_states
        value_hidden_state = value_hidden_state if value_hidden_state is not None else hidden_states
        key = attn.to_k(key_hidden_states)
        value = attn.to_v(value_hidden_state)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        for i in range(hidden_states.shape[0] // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
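
What the separate key/value inputs buy you can be seen in a tiny pure-Python version of single-head attention. This is only a sketch under simplifying assumptions: it skips the `to_q`/`to_k`/`to_v` projections, heads, and masks used in the processors above, and the helper names and numbers are made up. The point is that the attention weights come from the key context alone, while the content that gets mixed comes from the value context.

```python
import math


def softmax(row):
    # Numerically stable softmax over one list of scores.
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(query, key_states, value_states):
    """softmax(Q K^T / sqrt(d)) V for one head, on nested lists."""
    scale = 1.0 / math.sqrt(len(query[0]))
    output = []
    for q in query:
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in key_states]
        probs = softmax(scores)  # where to look: depends only on the keys
        # What to read: a weighted sum of the value rows.
        output.append([sum(p * v[j] for p, v in zip(probs, value_states))
                       for j in range(len(value_states[0]))])
    return output


query = [[1.0, 0.0], [0.0, 1.0]]          # two image positions
source_ctx = [[4.0, 0.0], [0.0, 4.0]]     # stands in for the source-prompt tokens
target_ctx = [[10.0, 0.0], [0.0, -10.0]]  # stands in for the edited-prompt tokens

plain = attention(query, source_ctx, source_ctx)
mixed = attention(query, source_ctx, target_ctx)  # layout from source, content from target
print(plain[0], mixed[0])
```

Both contexts need the same number of tokens, since each attention weight computed from a key row is applied to the value row at the same index.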

kashif avatar Jan 26 '23 20:01 kashif

Sure, this seems reasonable, guess would be great to see it in a pipeline class directly :-)

patrickvonplaten avatar Jan 26 '23 21:01 patrickvonplaten

Is this open? Would be happy to take it up!

unography avatar Jan 27 '23 06:01 unography

@unography yes it's open, please feel free to contribute!

kashif avatar Jan 27 '23 07:01 kashif

@kashif sure, will add a draft PR soon

unography avatar Jan 27 '23 07:01 unography

This looks plausible, thanks! A follow-up question: with the xformers implementation, how can we retrieve the softmaxed k*q attention map (before it is applied to the values)? See here: https://github.com/facebookresearch/xformers/blob/5df1f0b682a5b246577f0cf40dd3b15c1a04ce50/xformers/ops/fmha/__init__.py#L149

class XFormersCrossAttnKVProcessor:
    def __call__(
        self, attn: CrossAttention, hidden_states, key_hidden_states=None, value_hidden_state=None, attention_mask=None
    ):
        _, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        key_hidden_states = key_hidden_states if key_hidden_states is not None else hidden_states
        value_hidden_state = value_hidden_state if value_hidden_state is not None else hidden_states
        key = attn.to_k(key_hidden_states)
        value = attn.to_v(value_hidden_state)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
        hidden_states = hidden_states.to(query.dtype)
        hidden_states = attn.batch_to_head_dim(hidden_states)

evinpinar avatar Feb 01 '23 15:02 evinpinar

Taking a step back -- I question the actual usefulness of "prompt-to-prompt". Why would someone generate an image with the wrong prompt in the first place?? If I wanted a "box of cookies", why did I type "box of apples"?

Plus, there are now more powerful and flexible techniques available. The paper below requires no input prompt, just a raw image, from which it extracts various features from the diffusion layers and applies them to a new prompt. This seems much more in line with a normal image workflow than prompt-to-prompt. Cheers.

https://arxiv.org/pdf/2211.12572.pdf

Alchete avatar Feb 01 '23 17:02 Alchete

If useful for anyone, I've implemented Attend-and-Excite with the AttentionProcessors; an example is here: https://github.com/evinpinar/Attend-and-Excite-diffusers/blob/72fa567a1e3bb3cc1b63cb53a1d9db5fc10b241f/utils/ptp_utils.py#L57


class AttendExciteCrossAttnProcessor:

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)

        is_cross = encoder_hidden_states is not None
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        self.attnstore(attention_probs, is_cross, self.place_in_unet)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


def register_attention_control(model, controller):

    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.unet.config.block_out_channels[-1]
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.unet.config.block_out_channels[block_id]
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        attn_procs[name] = AttendExciteCrossAttnProcessor(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count
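
The name-based bookkeeping in `register_attention_control` can be tried out on plain strings. The sketch below is a hypothetical helper pair, not part of the linked repo, and the example keys are only assumed to follow the same pattern as `unet.attn_processors`. It mirrors the prefix checks above and the `attn1` (self-attention) versus `attn2` (cross-attention) convention the code relies on.

```python
def place_in_unet(processor_name):
    """Map an attention-processor key to 'down', 'mid' or 'up' (None if unknown)."""
    for prefix, place in (("down_blocks", "down"), ("mid_block", "mid"), ("up_blocks", "up")):
        if processor_name.startswith(prefix):
            return place
    return None


def is_cross_attention(processor_name):
    # Same convention as the code above: attn1 is self-attention, attn2 is cross-attention.
    return processor_name.endswith("attn2.processor")


# Example keys in the style of `unet.attn_processors` (illustrative, not exhaustive).
names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor",
]

summary = [(place_in_unet(n), "cross" if is_cross_attention(n) else "self") for n in names]
print(summary)
```

One thing to keep in mind with the original: `int(name[len("up_blocks.")])` reads a single character, so it assumes fewer than ten blocks per stage.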

evinpinar avatar Feb 03 '23 12:02 evinpinar

Super cool! @evinpinar feel free to open a PR to add this as a new pipeline. Maybe this PR is a good example of how to add a new simple pipeline: https://github.com/huggingface/diffusers/pull/2223

Amazing work :heart:

patrickvonplaten avatar Feb 03 '23 17:02 patrickvonplaten

@evinpinar Looks awesome!

isamu-isozaki avatar Feb 06 '23 01:02 isamu-isozaki

Btw, is there a PR like this for prompt-to-prompt? I just want to check out the implementation for research. If not, I'm happy to make one based on @evinpinar's code.

isamu-isozaki avatar Feb 06 '23 01:02 isamu-isozaki

Hi everyone, I just implemented a pipeline here: https://github.com/Weifeng-Chen/prompt2prompt, based on @evinpinar's code and borrowing from Google's prompt-to-prompt. It uses a 'controller' to replace, refine, or reweight. The controller currently lives outside the pipeline; I don't know whether it should be moved into the pipeline. Here is some reference code for now. Any advice on the API?

from pipeline_prompt2prompt import Prompt2PromptPipeline
from ptp_utils import AttentionStore, AttentionReplace, LocalBlend, AttentionRefine, AttentionReweight, view_images, get_equalizer
import torch
import numpy as np

g_cpu = torch.Generator().manual_seed(2333)
device = "cuda"

pipe = Prompt2PromptPipeline.from_pretrained("CompVis/stable-diffusion-v1-4").to(device)

prompts = ["A painting of a squirrel eating a burger",
           "A painting of a cat eating a burger"]

NUM_DIFFUSION_STEPS = 20
lb = LocalBlend(prompts, ("squirrel", "cat"), tokenizer=pipe.tokenizer, device=pipe.device)
controller = AttentionReplace(prompts, NUM_DIFFUSION_STEPS, cross_replace_steps=0.4, self_replace_steps=0.4, tokenizer=pipe.tokenizer, device=pipe.device, local_blend=lb)
outputs = pipe(prompt=prompts, height=512, width=512, num_inference_steps=NUM_DIFFUSION_STEPS,
               controller=controller, generator=g_cpu)
view_images([np.array(img) for img in outputs.images])

pipe.show_cross_attention(prompts, controller, res=16, from_where=("up", "down"), select=0)
pipe.show_cross_attention(prompts, controller, res=16, from_where=("up", "down"), select=1)

For more operations, have a look at https://github.com/Weifeng-Chen/prompt2prompt/blob/main/p2p_test.ipynb

Weifeng-Chen avatar Feb 07 '23 10:02 Weifeng-Chen

@Weifeng-Chen Thanks and awesome!

isamu-isozaki avatar Feb 07 '23 13:02 isamu-isozaki

Thanks @Weifeng-Chen

I have a dumb question: when doing a refinement, what do self_replace_steps and cross_replace_steps actually mean?

Say I want to switch between two prompts: "A painting of a squirrel eating a burger" and "A real photo of a squirrel eating a burger" at 0.7. What values do I set for these two arguments in AttentionReplace(...)?

asadm avatar Feb 09 '23 02:02 asadm

You can try it: cross_replace_steps=0., self_replace_steps=0. means no replacement, so a new image is generated entirely from scratch. As I understand it, during inference the new prompt produces its own cross-attn and self-attn maps, and these are replaced with the ones from the original prompt. Larger values make the result more similar to the original image but may restrict the editing. I haven't fully tested this, so point it out if I'm wrong.

Weifeng-Chen avatar Feb 09 '23 02:02 Weifeng-Chen

But then where does 0.7 go? 🤔

asadm avatar Feb 09 '23 02:02 asadm

You can try changing it. 0.7 means the first 70% of the steps use the original prompt's attention and the remaining 30% use the new one.

Weifeng-Chen avatar Feb 09 '23 02:02 Weifeng-Chen

Yes that's what I am trying to achieve. So does that mean I set both to .7? cross_replace_steps=0.7, self_replace_steps=0.7

Thank you!

asadm avatar Feb 09 '23 03:02 asadm

They don't need to be the same. Self-attn doesn't interact with the text.
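
To make the fraction-to-steps reading concrete, here is a small sketch. `replace_steps` and `uses_original_attention` are hypothetical helpers, and the sketch assumes a single global fraction truncated with `int()`; the actual controller in ptp_utils may schedule things differently (for example per token).

```python
def replace_steps(num_inference_steps, fraction):
    """Number of leading denoising steps that reuse the original prompt's attention."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    return int(num_inference_steps * fraction)


def uses_original_attention(step, num_inference_steps, fraction):
    # Steps are counted from 0; only the leading steps reuse the original maps.
    return step < replace_steps(num_inference_steps, fraction)


NUM_DIFFUSION_STEPS = 20
schedule = [uses_original_attention(s, NUM_DIFFUSION_STEPS, 0.7)
            for s in range(NUM_DIFFUSION_STEPS)]
print(schedule.count(True), "steps with original attention,",
      schedule.count(False), "steps with the new prompt's attention")
```

Under this reading, cross_replace_steps and self_replace_steps would each get their own count, which is why they can be set independently.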

Weifeng-Chen avatar Feb 09 '23 03:02 Weifeng-Chen

Hi, I'd like to take on the EDICT implementation, if someone hasn't started it.

Joqsan avatar Feb 16 '23 21:02 Joqsan

Any updates in this thread? Looking forward to it!

xvjiarui avatar Mar 01 '23 13:03 xvjiarui

Note that we've already added the pix2pix0 pipeline which is an improved version of prompt2prompt: https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/pix2pix_zero

I'm not sure how much sense prompt2prompt makes given that an improved version has already been added.

patrickvonplaten avatar Mar 02 '23 18:03 patrickvonplaten

I'm not necessarily pushing for it, but I will say that what methods like Prompt-to-Prompt and EDICT have over pix2pix zero is the lack of a need to generate source and target embeddings. In the case of editing real images, pix2pix zero would require you to not only undergo inversion steps, but also generate the source and target embeddings and get their difference before you can generate new images. With (the original) Prompt-to-Prompt paper as well as EDICT, you'd only need to undergo the inversion steps before generating the final images.
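
For readers unfamiliar with the "difference of source and target embeddings" step, a rough pure-Python sketch follows. It assumes the direction is the difference between the mean embeddings of a set of source captions and a set of target captions; the helper names and the 3-d vectors are made up, and real text-encoder outputs are per-token tensors rather than flat lists.

```python
def mean_embedding(embeddings):
    """Average a list of equal-length embedding vectors."""
    count = len(embeddings)
    return [sum(column) / count for column in zip(*embeddings)]


def edit_direction(source_embeddings, target_embeddings):
    # Direction from the source concept to the target concept in embedding space.
    source_mean = mean_embedding(source_embeddings)
    target_mean = mean_embedding(target_embeddings)
    return [t - s for s, t in zip(source_mean, target_mean)]


# Made-up 3-d vectors standing in for text-encoder outputs of captions
# about the source concept ("cat") and the target concept ("dog").
cat_caption_embeds = [[1.0, 0.0, 2.0], [3.0, 0.0, 2.0]]
dog_caption_embeds = [[2.0, 4.0, 2.0], [2.0, 6.0, 2.0]]

direction = edit_direction(cat_caption_embeds, dog_caption_embeds)
print(direction)
```

This is the extra preparation step that, per the comment above, Prompt-to-Prompt and EDICT do not need.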

ryan-caesar-ramos avatar Mar 04 '23 13:03 ryan-caesar-ramos

I agree with @ryan-caesar-ramos , I think those serve different purposes and both could be part of a toolbox on diffusers. I think we would love a community contributed PR on p2p and EDICT!

apolinario avatar Mar 05 '23 10:03 apolinario

I was planning to work on this, but ended up using the pix2pix pipeline instead.

But like @apolinario and @ryan-caesar-ramos mentioned, it would be cool to have this. I'll work on p2p this week and raise a PR

unography avatar Mar 06 '23 05:03 unography

With the release soon(tm) of p2p-video, this gets even more relevant imo: https://video-p2p.github.io

apolinario avatar Mar 10 '23 01:03 apolinario

Do I understand correctly that adding prompt-to-prompt re-weighting is not that difficult now, but that it's impossible to have it together with xformers, since we need to modify self-attention and xformers doesn't explicitly expose it?
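
For context, re-weighting acts on an explicit attention-probability matrix, and a fused kernel that only returns the attention output never hands that matrix back. A minimal pure-Python sketch of the operation, assuming a simple per-token scale with no renormalization (the thread does not say whether a given implementation renormalizes); `reweight` is a hypothetical helper:

```python
def reweight(attention_probs, token_index, scale):
    """Scale the attention paid to one text token at every image position."""
    return [
        [p * scale if j == token_index else p for j, p in enumerate(row)]
        for row in attention_probs
    ]


# Rows: image positions; columns: prompt tokens. Each row is a softmax output.
probs = [
    [0.5, 0.25, 0.25],
    [0.125, 0.75, 0.125],
]

boosted = reweight(probs, token_index=1, scale=2.0)
print(boosted)
```

The scaled matrix is then multiplied with the values as usual, so any processor doing this has to compute the probabilities itself, as the non-xformers processors earlier in the thread do.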

bonlime avatar Mar 10 '23 14:03 bonlime