dpm-solver
Possible to support img_callback and alternating prompts?
Hi, thanks for making your code work with Stable Diffusion.
I have a couple of requests if possible.
- Could you support the `img_callback` parameter? It seems to work okay in my limited testing. I'm using the version of your code included with Stable Diffusion 2.0, backported to 1.x (it works fine with no changes).
Supporting the `img_callback` function lets us render the current state every N steps. I'm using multistep sampling and found that adding code like this after `multistep_dpm_solver_update` works:
```python
x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
                                     solver_type=solver_type)
if img_callback: img_callback(x, step)
```
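For completeness, the callback itself can be something simple like this (just a sketch; `decode_first_stage` is the usual stable-diffusion latent decoder, and `save_preview` stands in for whatever display/save code you use):

```python
import torch

def make_img_callback(model, preview_every=5):
    # model is the LDM/stable-diffusion model; decode_first_stage maps
    # latents back to image space.
    def img_callback(x, step):
        if step % preview_every != 0:
            return
        with torch.no_grad():
            img = model.decode_first_stage(x)
        save_preview(img, step)  # save_preview: placeholder for your own display/save code
    return img_callback
```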
- Could you support the ability to alternate/cycle through prompts? This one seems a bit trickier and I haven't been able to code it myself yet, but I was able to do it for DDIM.
For example, I may have three separate prompts: A, B, C.
On step 1, A is used; on step 2, B is used; on step 3, C is used; on step 4, A is used again, and so on, cycling through the list of prompts/conds.
This is one of the workarounds people use to get around the 75/77-token limit for both cond and uncond.
This is what my code looks like in the DDIM sampler; as you can see, I essentially do `x_cond = cond[step % len(cond)]` to pick which cond in the list to use on each step:
```python
for i, step in enumerate(iterator):
    index = total_steps - i - 1
    ts = torch.full((b,), step, device=device, dtype=torch.long)

    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    cond_idx = i % len(cond)
    neg_cond_idx = i % len(unconditional_conditioning)
    x_cond = cond[cond_idx]
    x_uc = unconditional_conditioning[neg_cond_idx]

    outs = self.p_sample_ddim(img, x_cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                              quantize_denoised=quantize_denoised, temperature=temperature,
                              noise_dropout=noise_dropout, score_corrector=score_corrector,
                              corrector_kwargs=corrector_kwargs,
                              unconditional_guidance_scale=unconditional_guidance_scale,
                              unconditional_conditioning=x_uc)
    img, pred_x0 = outs
    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)
```
Thanks
Hi @zwishenzug, thank you for the very valuable suggestions!
- The `img_callback` argument:
Thank you for the suggestion. This argument is important for some downstream tasks such as image inpainting (e.g., adding masks at each step). I've supported it in the newest version of DPM-Solver and provided example code with stable-diffusion.
The corresponding argument is `correcting_xt_fn`, since it can be understood as correcting the sampled xt at time t. You can find detailed example code for image inpainting with stable-diffusion and DPM-Solver at this script.
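Roughly, the usage looks like this (just a sketch based on the current API; `mask` and `noised_known(t)` are hypothetical placeholders for the inpainting case, and the callback must return the corrected xt):

```python
def correcting_xt_fn(xt, t, step):
    # Sketch for inpainting: re-impose the known region at each step.
    # `mask` (== 1 over the known pixels) and `noised_known(t)` are
    # hypothetical placeholders, not part of the library.
    return xt * (1. - mask) + noised_known(t) * mask

dpm_solver = DPM_Solver(model_fn, noise_schedule,
                        algorithm_type="dpmsolver++",
                        correcting_xt_fn=correcting_xt_fn)
x = dpm_solver.sample(x_T, steps=20, order=2, method="multistep")
```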
- The cycling prompts:
This feature is strange to me because it actually changes the diffusion ODE at each step. Could you please give me some examples/motivations for why we need it? Thank you!
Thanks for responding.
- img_callback
Thanks, I understand now and can see how I can use `correcting_xt_fn`.
At the beginning of DPMSolverSampler::sample I can add the following code:
```python
def cb(x, t, step):
    if img_callback: img_callback(x, step)
    return x

correcting_xt_fn = cb
```
This is a much less intrusive way to resolve my issue while remaining compatible with the DDIM/PLMS samplers.
- Cycling prompts
These examples from the AUTOMATIC1111 user interface should give some insight into what people are doing with this:
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing
https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alternating-words
It's correct that the target is being changed on each step. Sometimes this can have some interesting effects. For example, if you change between "Photo of Barack Obama" and "Photo of Donald Trump" on alternating steps, it may be that you get a result which is somewhat a mix of the two people.
Or in the other example, if you start off with "a male military officer" for the first half of the process, then switch to "a female military officer" halfway through, you may get a more masculine woman as a result.
I've been able to get this working in my local version. I'm only concerned with classifier-free guidance and multistep, so it wasn't too hard, it seems.
It seems to be a case of modifying `model_fn` to become `model_fn(x, t_continuous, step)`, then making the code look up the cond/uncond in the list via step modulo:
```python
elif guidance_type == "classifier-free":
    if guidance_scale == 1. or unconditional_condition[step % len(unconditional_condition)] is None:
        return noise_pred_fn(x, t_continuous, cond=condition[step % len(condition)])
    else:
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t_continuous] * 2)
        c_in = torch.cat([unconditional_condition[step % len(unconditional_condition)],
                          condition[step % len(condition)]])
        noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
        return noise_uncond + guidance_scale * (noise - noise_uncond)
```
And making sure that the current step gets passed through from the main multistep code.
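The plumbing on the sampler side looks roughly like this (a sketch only; the variable names follow the multistep loop in dpm_solver_pytorch.py, but the `step` argument is my own addition):

```python
# Inside the multistep loop of DPM_Solver.sample (sketch of my local change):
for step in range(order, steps + 1):
    t = timesteps[step]
    x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t,
                                         step_order, solver_type=solver_type,
                                         step=step)   # forwarded down to model_fn
    for i in range(order - 1):
        t_prev_list[i] = t_prev_list[i + 1]
        model_prev_list[i] = model_prev_list[i + 1]
    t_prev_list[-1] = t
    if step < steps:
        model_prev_list[-1] = self.model_fn(x, t, step)  # step reaches the cond lookup
```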
I will need to do some proper testing, but it seems to be working okay so far.
Hi @zwishenzug ,
Yes, I suppose it is the easiest way to add the img_callback in stable-diffusion.
And now I understand the feature of cycling prompts. I will try to figure out a more general API for supporting this as soon as possible. Thank you for the examples!
Thank you for your hard work.
I do have another question: is it possible to support the img2img function of stable diffusion? I can see that you have implemented `stochastic_encode()`, but scripts/img2img.py also requires a `decode()` function, and I haven't been able to understand how to implement it myself.
Thanks
No problem. I will support it soon.
I've come back to this because today I realised that it's actually `correcting_x0_fn` which is preferable for this "preview" (to support the img_callback in stable diffusion in the same way as DDIM).
xt gives a preview that still includes the noise; x0 gives one without it.
Just leaving a note in case anyone else is implementing it for themselves.
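Something like this works for me (a sketch; it assumes the solver invokes `correcting_x0_fn(x0, t)` with no step index, as in the version I'm using, so I count calls with a closure — with multistep there is one model evaluation per step, so the count lines up):

```python
def make_preview_x0_fn(img_callback):
    # correcting_x0_fn receives the predicted clean sample x0,
    # so the previews come out noise-free.
    counter = {"step": 0}
    def correcting_x0_fn(x0, t):
        img_callback(x0, counter["step"])
        counter["step"] += 1
        return x0  # must return x0 (unchanged here, or corrected)
    return correcting_x0_fn
```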