
Official callbacks

Open asomoza opened this issue 10 months ago • 38 comments

What does this PR do?

Initial draft to support official callbacks.

This is the most basic implementation I could think of without the need of modifying the pipelines.

After this, we need to discuss whether we're going to modify the pipelines to support additional functionality:

On step begin

For this issue, for example, the proposal is to start the CFG after a certain step and to stop it after another step. To start the CFG from a callback we would need an additional callback, on_step_begin; the alternative is to do it manually with the embeds and pass them to the pipelines. The same would be needed for differential diffusion.

Automatic callback_on_step_end_tensor_inputs

With the current implementation the user needs to know what to add to the callback_on_step_end_tensor_inputs list. For example, for the SDXL implementation of the CFG cutout we need to add prompt_embeds, add_text_embeds, add_time_ids or it won't work. If we want to do this automatically I'll need to modify the pipelines; if not, I can add an error message indicating which values are missing.

The user already needs to know the args for each callback, so maybe it is better to just document this in a README for all the callbacks.
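As a rough sketch of the "error message" option: a helper that compares what a callback needs against what the user passed in callback_on_step_end_tensor_inputs. The helper name `check_tensor_inputs` and the `REQUIRED` table are hypothetical and not part of diffusers; the SDXL names are the ones listed above, and the SD 1.5 entry is an assumption.

```python
from typing import Dict, List

# Hypothetical table of the tensor names each callback needs.
# The SDXL entry uses the names from the description above;
# the SD 1.5 entry is an assumption for illustration.
REQUIRED: Dict[str, List[str]] = {
    "SDXLCFGCutoff": ["prompt_embeds", "add_text_embeds", "add_time_ids"],
    "SDCFGCutoff": ["prompt_embeds"],
}


def check_tensor_inputs(callback_name: str, provided: List[str]) -> None:
    """Raise a ValueError naming every required tensor input that is missing."""
    missing = [name for name in REQUIRED[callback_name] if name not in provided]
    if missing:
        raise ValueError(
            f"{callback_name} needs these entries in "
            f"callback_on_step_end_tensor_inputs: {missing}"
        )
```

Calling `check_tensor_inputs("SDXLCFGCutoff", ["prompt_embeds"])` would raise and list the two missing names, while a complete list passes silently.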

Chain callbacks

Should we add the functionality to chain callbacks? For example, to accept a list of callbacks so we can use the CFG and IP cutouts at the same time? The alternative is to create another callback that does both of them.

  • [X] Draft
  • [X] Automatic callback_on_step_end_tensor_inputs
  • [X] Chain callbacks
  • [ ] On step begin?
  • [ ] clean code
  • [ ] review

Fixes #7736

Example usage:

# for SD 1.5
from diffusers.callbacks import SDCFGCutoffCallback

callback = SDCFGCutoffCallback(cutoff_step_ratio=0.4)
# can also be used with cutoff_step_index
# callback = SDCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
).images[0]
[Image comparison: four results, labeled 0.2, 0.4, 0.8 and 1.0]

# for SDXL
from diffusers.callbacks import SDXLCFGCutoffCallback

callback = SDXLCFGCutoffCallback(cutoff_step_ratio=0.4)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    callback_on_step_end=callback,
).images[0]
[Image comparison: four results, labeled 0.2, 0.4, 0.8 and 1.0]

# IP Adapter cutout
from diffusers.callbacks import IPAdapterScaleCutoffCallback

callback = IPAdapterScaleCutoffCallback(cutoff_step_ratio=0.4)

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    ip_adapter_image=ip_image,
    callback_on_step_end=callback,
).images[0]
[Image comparison: the IP image source, followed by three results labeled 0.3, 0.5 and 1.0]

# Callback list
from diffusers.callbacks import IPAdapterScaleCutoffCallback, MultiPipelineCallbacks, SDXLCFGCutoffCallback

ip_callback = IPAdapterScaleCutoffCallback(cutoff_step_ratio=0.5)
cfg_callback = SDXLCFGCutoffCallback(cutoff_step_ratio=None, cutoff_step_index=10)

callbacks = MultiPipelineCallbacks([ip_callback, cfg_callback])

image = pipe(
    prompt=prompt,
    negative_prompt="",
    guidance_scale=6.5,
    num_inference_steps=25,
    generator=generator,
    ip_adapter_image=ip_image,
    callback_on_step_end=callbacks,
).images[0]
[Image comparison: three results, labeled "IP:1.0 - CFG 1.0", "IP:1.0 - CFG 0.4" and "IP 0.5 - CFG 0.4"]

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

asomoza avatar Apr 24 '24 05:04 asomoza

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

cc @a-r-r-o-w here too

yiyixuxu avatar Apr 24 '24 06:04 yiyixuxu

I like this - very nice and simple

for the next steps, let's see a proposal for this?

Automatic callback_on_step_end_tensor_input

after that, we can play around with on_the_step_begin; I think we don't have to consider chain callbacks for now, but open to it if we see many use cases for it in the future

yiyixuxu avatar Apr 24 '24 16:04 yiyixuxu

actually would be nice to support list of callbacks since now we provide official ones that user can mix and match

yiyixuxu avatar Apr 24 '24 17:04 yiyixuxu

actually would be nice to support list of callbacks since now we provide official ones that user can mix and match

Yeah, I think this is the right way to do it. In fact I would say to not even use "callbacks" but rather just a pure function for doing each sampling step called default_sampling_function, which is a pure function into which we pass all things possibly required for a single sampling step.

Basically we have the sampling loop (SD pipeline as an example)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=self.cross_attention_kwargs,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

In my opinion, we should replace everything underneath for i, t in enumerate(timesteps): with a default_sampling_function function defined elsewhere in the pipeline. The default_sampling_function should do the normal sampling step and take a SamplingInput dataclass consisting of the things required to perform the sampling step, then return a SamplingOutput which consists of the things that are normally returned from the sampling step. The SamplingOutput can then be fed back in as a SamplingInput, with the next timestep, to the default_sampling_function.

We can add an argument sampling_functions: list[Callable]=[default_sampling_function] into the __call__ as a new, backwards compatible kwarg.

In this way we can finally get complete control over the sampling loop and chain multiple functions together to process the output of the sampling loop in the order of the sampling functions.

the specifics of that i'll have to wrap my head around but the initial idea of decoupling the logic inside __call__ so it can be more effectively monkeypatched downstream sounds ideal. both things can be done, really

bghira avatar Apr 26 '24 21:04 bghira

there's the concept library on the hf hub from back in the day. for the uninitiated, it is/was a collection of dreambooths others had done, to make it easier to find eg. a backpack checkpoint or some other oddly specific item you reliably needed to work.

i know it's a security nightmare, but the idea of hub-supported callbacks "calls to me" as something worth bringing up.

on the other hand, having community callbacks in this repo is time-consuming but that allows thorough review of any callbacks that are included. unlike dreambooths, callbacks seem like they'd be rarely created, whereas there are a billion potential concepts for a dreambooth.

listing the available callbacks is quite trivial in either case, where a diffusers:callbacks tag or something can be used to differentiate them. scanning these for safety issues with an LLM would possibly help sanitise any obvious issues?

bghira avatar Apr 26 '24 21:04 bghira

It's probably easier if I write it out in some pseudocode. Writing it down, I think SamplingOutput is probably redundant, so maybe we could come up with a better name for that dataclass.

class SamplingInput:
    def __init__(self, img, text_embedding, unet, timestep=None, **kwargs):
        self.img = img
        self.text_embedding = text_embedding
        self.unet = unet
        self.timestep = timestep

# ... lots of other code ...

        inp = SamplingInput(img, text_embedding, unet)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                inp.timestep = t
                for sampling_function in self.sampling_functions:
                    inp = sampling_function(inp)
                    
        output_img = inp.img

Then we can do something like sampling_functions=[default_sampling_function, report_image_on_step], where the second function is just a sampling function that does nothing with the input, ships the in-progress image to an API somewhere to do realtime updates of inference, then returns the original input.

I really like the array of function pointers. It makes composition easy and clearly signals that the methods are designed to be changeable.
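To make the chaining idea concrete, here is a minimal runnable sketch. It is not a diffusers API: `SamplingState`, `run_sampling` and `report_on_step` are hypothetical names, and the "sampling step" is a stand-in arithmetic operation rather than a UNet call plus a scheduler step.

```python
from dataclasses import dataclass, field, replace
from typing import Callable, List


@dataclass
class SamplingState:
    """Hypothetical container for everything one sampling step needs."""
    latents: float
    timestep: int = 0
    log: List[int] = field(default_factory=list)


SamplingFunction = Callable[[SamplingState], SamplingState]


def default_sampling_function(state: SamplingState) -> SamplingState:
    # Stand-in for "predict noise, apply guidance, scheduler.step".
    return replace(state, latents=state.latents * 0.5)


def report_on_step(state: SamplingState) -> SamplingState:
    # Side effect only: record the timestep, then hand the state back as-is.
    state.log.append(state.timestep)
    return state


def run_sampling(
    state: SamplingState,
    timesteps: List[int],
    sampling_functions: List[SamplingFunction],
) -> SamplingState:
    for t in timesteps:
        state = replace(state, timestep=t)
        # Each function receives the output of the previous one, in order.
        for fn in sampling_functions:
            state = fn(state)
    return state
```

With `sampling_functions=[default_sampling_function, report_on_step]`, the second function sees the state produced by the first on every step, which is the composition behaviour described above.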

Beinsezii avatar Apr 27 '24 00:04 Beinsezii

ohh thanks! for this PR we will keep it simple and support official callbacks with minimum change to our pipelines. what you proposed will introduce a pretty drastic change to our design and I think it is outside the scope of this PR, so maybe it is better to open a new discussion at https://github.com/huggingface/diffusers/discussions instead?

yiyixuxu avatar Apr 28 '24 04:04 yiyixuxu

a bit late to the party here, but adding one use-case: modifying or skipping steps. right now, loop is fixed and no matter what happens in callbacks, they cannot influence it:

            for i, t in enumerate(timesteps):

big use case is for callback to actually modify timesteps in some sense - perhaps we want to skip a step? perhaps force an early end since callback function determined it got what it needed and there is no point of running all the remaining steps to completion?

vladmandic avatar Apr 28 '24 04:04 vladmandic

a bit late to the party here, but adding one use-case: modifying or skipping steps. right now, loop is fixed and no matter what happens in callbacks, they cannot influence it:

            for i, t in enumerate(timesteps):

big use case is for callback to actually modify timesteps in some sense - perhaps we want to skip a step? perhaps force an early end since callback function determined it got what it needed and there is no point of running all the remaining steps to completion?

The scheduler modifies which timesteps are in the timestep list, so determination of timesteps to run lives there. You can very simply just write your own scheduler to exclude some timesteps.

or i guess a scheduler wrapper that takes in its own callbacks, in the case of SD.Next

bghira avatar Apr 28 '24 13:04 bghira

ohh thanks! for this PR we will keep it simple and support official callbacks with minimum change to our pipelines. what you proposed will introduce a pretty drastic change to our design and I think it is outside the scope of this PR, so maybe it is better to open a new discussion at https://github.com/huggingface/diffusers/discussions instead?

It's a continuation of #7736, but engineering a proper solution rather than a half-baked one will save time in the long run.

For example, right now for determining timesteps we have schedulers -- the scheduler is effectively a function you can pass into the pipeline that is relatively pure and just gets which timesteps you are supposed to perform, for the most part. Ideally we extend such functional designs to the sampling loop as well, and in this case, extend the ability to run multiple sampling functions in sequence.

I believe this solves every current and previous deficit that hacks like

        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],

have failed to fully address.

In my opinion this is poorly engineered, and now that it exists in the codebase it will need to be supported with backwards compatibility for the rest of time whereas I believe my proposed solution is (1) clean (2) consistent with the engineering of schedulers and (3) will not result in technical debt, but will incur a large one time cost to support by various pipelines. For now, just a few of the most used pipelines could be done and the rest stubbed with NotImplementedError.

The scheduler modifies which timesteps are in the timestep list, so determination of timesteps to run lives there. You can very simply just write your own scheduler to exclude some timesteps.

@AmericanPresidentJimmyCarter i get that, but i don't want to monkey-patch all schedulers existing in diffusers.

example use case - there are some experimental sd15 models popping-up that are only finetuned on high-noise or low-noise - with idea behind them very similar to sdxl-refiner, but stabilityai never did refiner for sd15 and there is no pipeline for it. so use case would be to allow initial run to "stop early" (e.g. at 80% of timesteps) so another model can continue from 80%+ of its timesteps (meaning we need to be able to set initial timestep).
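a rough sketch of that hand-off, assuming the timesteps are a plain descending list: split the list at a fraction so the first model runs the high-noise part and the second model continues from there. `split_timesteps` is a hypothetical helper, not a diffusers API, and a real scheduler would likely need more than list slicing.

```python
from typing import List, Tuple


def split_timesteps(timesteps: List[int], fraction: float) -> Tuple[List[int], List[int]]:
    """Split a timestep list so the first model runs `fraction` of the steps."""
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("fraction must be between 0 and 1")
    cut = int(len(timesteps) * fraction)
    # First part: steps for the initial model; second part: steps for the follow-up model.
    return timesteps[:cut], timesteps[cut:]


# Ten evenly spaced timesteps from 900 down to 0 (illustrative values only).
base_steps, refiner_steps = split_timesteps(list(range(900, -1, -100)), 0.8)
```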

or i guess a scheduler wrapper that takes in its own callbacks, in the case of SD.Next

@bghira i might as well need to do that, i thought since we're talking about callbacks design here this would be a place to address future needs.

vladmandic avatar Apr 28 '24 14:04 vladmandic

that's a good point vlad. i was just thinking a preliminary attempt at a scheduler wrapper might result in some lessons being discovered that might help make a better upstream (diffusers) design. but maybe you already have a concrete idea? :P

also #4355 for your SD 1.x refiner needs.

bghira avatar Apr 28 '24 14:04 bghira

Yeah, this would require even more re-engineering. You would need sampling functions to be a part of the scheduler, and all of them would need to be passed to the scheduler instead of the pipeline. The net effect is more or less the same.

So for every pipeline, we would have a default sampling function which we pass to the default scheduler, and we could also pass multiple of these as I proposed. Then the only difference is that in the sampling loop we call self.scheduler.sample(...).

or a very simple hack using existing callback concept:

  • allow modification of timesteps array from a callback
  • add something like this:
for i, t in enumerate(timesteps):
  if t <= 0:
    continue

vladmandic avatar Apr 28 '24 15:04 vladmandic

Why wouldn't you just subclass the scheduler and then overwrite the get timestep method? That seems trivial?

@vladmandic this should solve your problem, no? https://huggingface.co/docs/diffusers/v0.27.2/en/using-diffusers/callback#interrupt-the-diffusion-process
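For reference, the mechanism in the linked docs is a callback that sets `pipeline._interrupt = True`; the loop then hits the `if self.interrupt: continue` check quoted earlier in this thread and skips the remaining steps. A minimal sketch with a stand-in class (`FakePipeline` is hypothetical and only mimics the shape of the loop, not a real pipeline):

```python
class FakePipeline:
    """Hypothetical stand-in that mimics only the shape of the denoising loop."""

    def __init__(self):
        self._interrupt = False
        self.steps_run = []

    @property
    def interrupt(self):
        return self._interrupt

    def __call__(self, timesteps, callback_on_step_end=None):
        self._interrupt = False
        for i, t in enumerate(timesteps):
            if self.interrupt:
                continue  # remaining steps are skipped, not executed
            self.steps_run.append(i)  # stand-in for the UNet + scheduler step
            if callback_on_step_end is not None:
                callback_on_step_end(self, i, t, {})
        return self.steps_run


def interrupt_callback(pipeline, step_index, timestep, callback_kwargs):
    # Ask the loop to stop once the step with index 2 has finished.
    if step_index == 2:
        pipeline._interrupt = True
    return callback_kwargs
```

Note that this only covers the "force an early end" case; it does not let a callback reorder or rewrite the timesteps.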

yiyixuxu avatar Apr 28 '24 18:04 yiyixuxu

@vladmandic I think it's what you proposed here, already implemented https://github.com/huggingface/diffusers/blob/56bd7e67c2e01122cc93d98f5bd114f9312a5cce/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L970

add something like this: for i, t in enumerate(timesteps): if t <= 0: continue

yiyixuxu avatar Apr 28 '24 18:04 yiyixuxu

@AmericanPresidentJimmyCarter like I said, this PR's scope is

for this PR we will keep it simple and support official callbacks with minimum change to our pipelines

feel free to open another issue or discussion

yiyixuxu avatar Apr 28 '24 19:04 yiyixuxu

Opened as #7808

Really nice comments and suggestions. I'm taking note of them all, but for the meantime I'm sticking to the original plan, which is to make minimal changes to what we already have.

I added:

  • Automatic detection of inputs
  • List of callbacks

I'm not that happy with the method for automatically getting the inputs but it was the best I could think of, still it doesn't feel right to me but I wanted to keep the creation of the callbacks as functions and not objects to build on what we already have.

I'm open to suggestions if someone can think of a better system, the alternative is to just use objects which we'll probably need if we extend the functionality for on_step_begin so that a single object can have two or more callbacks.

asomoza avatar Apr 30 '24 00:04 asomoza

I think chaining the callbacks should be handled outside the pipeline.

what do you think about this design?

first, let's define a base Callback class like this

from functools import partial
from typing import List


class PipelineCallback:

    def __init__(self, **kwargs):
        # bind the callback-specific settings once, at construction time
        self.func = partial(self.callback_fn, **kwargs)

    @property
    def tensor_inputs(self) -> List[str]:
        raise NotImplementedError(...)

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs, **kwargs):
        pass

    def __call__(self, pipeline, step_index, timestep, callback_kwargs):
        return self.func(pipeline, step_index, timestep, callback_kwargs)

the official callbacks should all inherit from this base class, e.g.

class IPAdapterScaleCutoutCallback(PipelineCallback):
    tensor_inputs = []
    
    def __init__(self, step_ratio=1):
        super().__init__(step_ratio=step_ratio)
    
    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs, step_ratio=1):
        if step_index == int(pipeline.num_timesteps * step_ratio):
            pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs

user can use this callback like this

ip_callback = IPAdapterScaleCutoutCallback(step_ratio=0.5)

for a list of callbacks, we should define another class, e.g. MultiPipelineCallbacks, which:

  1. take a list of callbacks in its __init__
  2. it also has a tensor_inputs property that automatically combines all the callbacks' tensor_inputs into a list
  3. it has a __call__ function that will call the callbacks in sequential order

user can use a list of callbacks like this

callbacks = MultiPipelineCallbacks([ip_callback, cfg_callback])

with this design, the change we need to make to the pipeline should be minimal, just need to define callback_on_step_end_tensor_inputs
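The three points for the list class could be sketched roughly as below. This is an illustration of the proposal, not the merged implementation: the class is named `MultiPipelineCallbacksSketch` to make that clear, and `AddOne` and `Double` are toy callbacks invented here so the example runs without a pipeline.

```python
from typing import Any, Dict, List


class MultiPipelineCallbacksSketch:
    """Illustrative only: wraps a list of callbacks and runs them in order."""

    def __init__(self, callbacks: List[Any]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        # Combine the tensor_inputs of every callback, keeping order, dropping repeats.
        names: List[str] = []
        for cb in self.callbacks:
            for name in cb.tensor_inputs:
                if name not in names:
                    names.append(name)
        return names

    def __call__(self, pipeline, step_index, timestep, callback_kwargs: Dict[str, Any]):
        # Each callback receives the kwargs returned by the previous one.
        for cb in self.callbacks:
            callback_kwargs = cb(pipeline, step_index, timestep, callback_kwargs)
        return callback_kwargs


class AddOne:  # toy callback, not part of diffusers
    tensor_inputs = ["latents"]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs):
        callback_kwargs["latents"] = callback_kwargs["latents"] + 1
        return callback_kwargs


class Double:  # toy callback, not part of diffusers
    tensor_inputs = ["latents", "prompt_embeds"]

    def __call__(self, pipeline, step_index, timestep, callback_kwargs):
        callback_kwargs["latents"] = callback_kwargs["latents"] * 2
        return callback_kwargs


multi = MultiPipelineCallbacksSketch([AddOne(), Double()])
```

Because the wrapper exposes the same `tensor_inputs` and `__call__` shape as a single callback, the pipeline would not need to know whether it received one callback or several.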

yiyixuxu avatar May 01 '24 04:05 yiyixuxu

what do you think about this design?

I like that design a lot more. That's partially what I meant when I said to use objects instead of functions, and it's the reason for my doubts: I much prefer classes and inheritance over plain functions, but I wasn't sure if we should change the callbacks to classes instead of functions when creating them.

I'll implement it like this. It also helps a lot in learning how to do future PRs, thanks.

asomoza avatar May 01 '24 07:05 asomoza

@asomoza yes for official callbacks, since we need to define both the inputs and the function, I think class makes more sense. the user still should be able to pass a function as a callback though.

this PR ends up being more complex than I had thought :) I thought we just needed to define a few functions users can import. But indeed it is a good practice for future PRs

yiyixuxu avatar May 01 '24 09:05 yiyixuxu

@yiyixuxu I like the idea of cutoff_step_index but we need to set a default and I don't think it's a good idea to set it to 0 but we don't know the exact value of the max step_index here, so set it to 1000?

asomoza avatar May 06 '24 23:05 asomoza

I like the idea of cutoff_step_index but we need to set a default and I don't think it's a good idea to set it to 0 but we don't know the exact value of the max step_index here, so set it to 1000?

I think you can set the index default to None and ratio default to 1: you can only use either the index or the ratio, so if the index is None and ratio is not None, we use ratio, and vice versa; when both are None or both are not None, we throw an error
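A small sketch of that rule, assuming the cutoff step is derived from the ratio as `int(num_timesteps * ratio)` (the same expression used in the IP adapter example earlier in the thread). `resolve_cutoff_step` is a hypothetical helper name, not something from the PR.

```python
from typing import Optional


def resolve_cutoff_step(
    num_timesteps: int,
    cutoff_step_ratio: Optional[float] = 1.0,
    cutoff_step_index: Optional[int] = None,
) -> int:
    """Return the step index to cut off at, from exactly one of ratio or index."""
    # Error when both are None or both are set: exactly one must be provided.
    if (cutoff_step_ratio is None) == (cutoff_step_index is None):
        raise ValueError("Provide exactly one of cutoff_step_ratio or cutoff_step_index.")
    if cutoff_step_index is not None:
        return cutoff_step_index
    return int(num_timesteps * cutoff_step_ratio)
```

With these defaults, using the index means passing `cutoff_step_ratio=None` explicitly, which matches the usage examples in the PR description.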

yiyixuxu avatar May 07 '24 05:05 yiyixuxu

@asomoza should this go into the official callback too? https://huggingface.co/docs/diffusers/v0.27.2/en/using-diffusers/callback#display-image-after-each-generation-step

yiyixuxu avatar May 07 '24 19:05 yiyixuxu