diffusers
feat : add log-rho deis multistep scheduler
Add deis multistep methods
- I fit the polynomial in log-rho space instead of the linear t space used in the original DEIS paper. This modification gives closed-form coefficients for the exponential multistep update, so it does not rely on a numerical solver.
- I can add other variants if we have vmap and grad support in torch.
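A minimal sketch of what "closed-form coefficients" can look like for a second-order update. It assumes rho = sigma / alpha, an update of the form x_t / alpha_t = x_s0 / alpha_s0 + c0 * eps_s0 + c1 * eps_s1, and coefficients obtained by integrating the Lagrange basis in log-rho over rho. The function names and the noise levels are hypothetical and are not taken from the PR; the quadrature is only there to check the closed form.

```python
import math


def lagrange_integral(rho, rho_b, rho_c):
    """Antiderivative in rho of (log r - log rho_c) / (log rho_b - log rho_c)."""
    return rho * (math.log(rho) - math.log(rho_c) - 1.0) / (math.log(rho_b) - math.log(rho_c))


def second_order_coefficients(rho_t, rho_s0, rho_s1):
    """Closed-form weights for the two most recent network outputs (assumed form)."""
    c0 = lagrange_integral(rho_t, rho_s0, rho_s1) - lagrange_integral(rho_s0, rho_s0, rho_s1)
    c1 = lagrange_integral(rho_t, rho_s1, rho_s0) - lagrange_integral(rho_s0, rho_s1, rho_s0)
    return c0, c1


def quadrature_coefficient(rho_t, rho_s0, rho_b, rho_c, n=20000):
    """Midpoint-rule estimate of the same integral, used only as a sanity check."""
    h = (rho_t - rho_s0) / n
    total = 0.0
    for i in range(n):
        r = rho_s0 + (i + 0.5) * h
        total += (math.log(r) - math.log(rho_c)) / (math.log(rho_b) - math.log(rho_c))
    return total * h


# Hypothetical noise levels, decreasing along the sampling trajectory.
rho_s1, rho_s0, rho_t = 3.0, 2.0, 1.2
c0, c1 = second_order_coefficients(rho_t, rho_s0, rho_s1)
c0_numeric = quadrature_coefficient(rho_t, rho_s0, rho_s0, rho_s1)
c1_numeric = quadrature_coefficient(rho_t, rho_s0, rho_s1, rho_s0)
print(c0, c0_numeric)
print(c1, c1_numeric)
```

Because the two Lagrange basis functions sum to one, c0 + c1 should equal rho_t - rho_s0, which gives a quick consistency check on any such coefficients.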
For a comparison with DPM-Solver on Stable Diffusion, see the paper linked below. Note that the DEIS entry in its table is t-AB-DEIS, not log-rho-AB-DEIS: the former fits a polynomial in t space, while the latter fits a polynomial in log-rho space.
https://arxiv.org/pdf/2211.01095.pdf
Hi @qsh-zh ,
The comparison seems not fair because I used your t-AB-DEIS in my paper :(
Is the log-rho-DEIS better than t-AB-DEIS or DPM-Solver? (On CIFAR-10 it seems that log-rho-DEIS is worse than DPM-Solver?)
@LuChengTHU Thanks for your discussion. You are right, the number in the table is for t-AB-DEIS.
Actually, inspired by your work and the Variational Diffusion Models paper, I find that fitting a polynomial in log-rho (lambda in your work) is more reasonable than fitting in $\rho$ for image diffusion models. Empirically, rho-AB-DEIS is worse than DPM-Solver. After switching to log-rho space, log-rho-AB-DEIS and DPM-Solver++ multistep have similar performance on Stable Diffusion in my preliminary experiments. I should run more experiments to test them.
I initially wanted to merge the t-AB method, but I am not sure whether the diffusers library supports JAX in its PyTorch code: t-AB requires a numerical integrator, and the vmap and grad tools would make the implementation much easier.
I modified the PR message to clarify this point, and I look forward to merging other variants into the codebase and having more comparisons between DEIS and DPM-Solver.
BTW, I would like to have more discussions in person if you are at NeurIPS. Let me know if you are available (fingers crossed).
Best, Qinsheng
Hi @qsh-zh ,
Thank you for modifying the message. I'm sorry, the comments I left above seem too rude, and I apologize for that :( . Actually, I'm a big fan of many of your papers (DEIS, gDDIM, PIS, DiffFlow). It's my pleasure to discuss with you!
It's very interesting to know that log-rho-AB-DEIS and DPM-Solver++ share many similarities! I'm curious if such findings can bring some new insights to the community or even encourage some new faster solvers :)
In addition, a small suggestion: maybe you can first implement the JAX version of DEIS? Many users also use the JAX version of Diffusers, and I believe it is worth merging your excellent t-AB version into Diffusers :)
It is a pity that I cannot go to NeurIPS because of the Covid-19 policy in China... But I'm also eager to discuss with excellent ML researchers like you. I've sent an email to you and I'm sincerely looking forward to your reply!
Best, Cheng
Hey @qsh-zh,
Thanks a lot for the PR! Does it already work with stable diffusion? :-)
@LuChengTHU We can chat more about research and collaborations.
@patrickvonplaten I tested it locally, and it works with Stable Diffusion v1.x. The code is heavily based on @LuChengTHU's DPM-Solver++ multistep scheduler; the main difference lies in the multistep coefficients. I still need to validate it with v2.0 and report some qualitative results in this thread.
Hey @qsh-zh,
anything we can help with or should we merge as is for now?
@patrickvonplaten I tested it locally. It works for the stable diffusion pipe.
Compare against dpm++ multistep
setup
Following the setup introduced in the DPM++ paper, I generate latent codes with DPM++ and DEIS from the first 200 COCO prompts, using the same random noise for both. The "ground truth" latent codes for that noise come from 999-step DDIM (the paper originally uses 1000 steps, but I ran into a numerical issue with that in the diffusers codebase; I expect the difference between 999-step and 1000-step DDIM to be small). Finally, I compare the squared L2 distance between the samples generated by each scheduler and the "ground truth" latent codes.
performance
| NFE | 10 | 15 | 20 | 30 | 50 |
| --- | --- | --- | --- | --- | --- |
| DEIS | 25.93 | 21.25 | 15.39 | 8.77 | 3.89 |
| DPM++ | 25.37 | 22.49 | 15.95 | 8.80 | 3.88 |
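To make the table easier to read, here is a small sketch that computes the relative gap between the two rows. The numbers are copied from the table above; the variable names are made up for illustration. Since the metric is a distance to the reference latents, a negative gap means DEIS is closer to the reference at that NFE.

```python
# Values copied from the table above (distance to the 999-step DDIM reference).
nfe = [10, 15, 20, 30, 50]
deis = [25.93, 21.25, 15.39, 8.77, 3.89]
dpmpp = [25.37, 22.49, 15.95, 8.80, 3.88]

# Relative gap of DEIS with respect to DPM++ at each NFE.
relative_gap = {n: (d - p) / p for n, d, p in zip(nfe, deis, dpmpp)}

for n, gap in relative_gap.items():
    print(f"NFE {n:>2}: DEIS vs DPM++ {gap:+.2%}")
```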
Some comments
- Both DEIS and DPM / DPM++ are based on the variation-of-constants formula. They differ in how they approximate the nonlinear term (the neural network) in the diffusion ODE.
  - DEIS: multistep method that extrapolates a polynomial fitted to previous network evaluations. (epsilon based)
  - DPM: single-step method that constructs a polynomial from multiple network evaluations within a single step. (epsilon based)
  - DPM++ multistep: constructs a polynomial from previous network evaluations. The high-order derivatives of the polynomial are determined by finite differences on previous network evaluations. (x0 based)
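As a small illustration of the two descriptions above ("extrapolating a polynomial" versus "finite differences on previous evaluations"), the sketch below shows that with two previous evaluations both constructions give the same line. It uses scalar stand-ins for network outputs and a generic variable u; the function names and numbers are hypothetical. It does not reproduce either scheduler, which also differ in whether epsilon or x0 is the interpolated quantity.

```python
def lagrange_extrapolate(u, u0, f0, u1, f1):
    """Evaluate the line through (u0, f0) and (u1, f1) using Lagrange basis weights."""
    l0 = (u - u1) / (u0 - u1)
    l1 = (u - u0) / (u1 - u0)
    return l0 * f0 + l1 * f1


def finite_difference_extrapolate(u, u0, f0, u1, f1):
    """Evaluate the same line as f0 plus a finite-difference slope times the offset."""
    slope = (f0 - f1) / (u0 - u1)
    return f0 + slope * (u - u0)


# Hypothetical previous evaluations at u0 and u1, extrapolated to a new point u.
u0, f0 = 1.0, 0.5
u1, f1 = 2.0, 0.9
u = 0.4
a = lagrange_extrapolate(u, u0, f0, u1, f1)
b = finite_difference_extrapolate(u, u0, f0, u1, f1)
print(a, b)
```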
code to reproduce comparison
# !pip install jammy
# !modify pipe_stable_diffusion.py to return latent code instead of decoded image
import torch
from src.diffusers.schedulers.scheduling_deis_multistep import DEISMultistepScheduler
from src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from src.diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
import random
import numpy as np
import jammy.io as jio
dpm = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    solver_order=2,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
)
deis = DEISMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    ab_order=2,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="deis",
    lower_order_final=True,
)
# First-order DEIS is used as the DDIM reference sampler.
ddim = DEISMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    ab_order=1,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="deis",
    lower_order_final=True,
)
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16")
pipe = pipe.to(device)
import urllib.request, json
# Download the COCO caption file and keep only the caption strings.
with urllib.request.urlopen("https://raw.githubusercontent.com/tylin/coco-caption/master/results/captions_val2014_fakecap_results.json") as url:
    coco_text = json.load(url)
coco_prompt = [item['caption'] for item in coco_text]
def num_to_groups(num, divisor):
    # Split num items into batches of size divisor, plus one smaller final batch if needed.
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def generate_fn(seed, prompts, scheduler, num_inference_steps, batch_size=20):
    pipe.scheduler = scheduler
    len_prompts = len(prompts)
    groups = num_to_groups(len_prompts, batch_size)
    imgs = []
    counter = 0
    for cur_batch_size in groups:
        # Re-seed before every batch so that each scheduler sees the same noise.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        _prompts = prompts[counter:counter+cur_batch_size]
        counter = counter+cur_batch_size
        # Relies on the locally modified pipeline that returns latent codes.
        _imgs = pipe(_prompts, num_inference_steps=num_inference_steps, output_type="latent")
        imgs.append(_imgs)
    return torch.cat(imgs)
num_text = 200
# generate ddim gt
ddim_img = generate_fn(1, coco_prompt[:num_text], ddim, 999, 20)
jio.dump(f'coco{num_text}_ddim999.pth', ddim_img)
# deis and dpm
def deis_dpm_fn(prompts, num_inference_steps, batch_size=20):
    deis_img = generate_fn(1, prompts, deis, num_inference_steps, batch_size)
    dpm_img = generate_fn(1, prompts, dpm, num_inference_steps, batch_size)
    return deis_img, dpm_img

for step in [10, 15, 20, 30, 50]:
    deis_img, dpm_img = deis_dpm_fn(coco_prompt[:num_text], step)
    jio.dump(f'coco{num_text}_deis{step}.pth', deis_img)
    jio.dump(f'coco{num_text}_dpm{step}.pth', dpm_img)
# compare two methods
ddim_gt = torch.tensor(
    jio.load(f'coco{num_text}_ddim999.pth'), dtype=torch.float32
)
diff_mean = {
    'dpm': [],
    'deis': [],
}
diff_std = {
    'dpm': [],
    'deis': [],
}
for step in [10, 15, 20, 30, 50]:
    # Elementwise squared difference to the DDIM reference, DEIS first.
    deis_img = jio.load(f'coco{num_text}_deis{step}.pth')
    diff = torch.abs(ddim_gt - deis_img)**2
    diff_mean['deis'].append(diff.mean().item())
    diff_std['deis'].append(diff.std().item())
    # Same statistic for DPM++.
    dpm_img = jio.load(f'coco{num_text}_dpm{step}.pth')
    diff = torch.abs(ddim_gt - dpm_img)**2
    diff_mean['dpm'].append(diff.mean().item())
    diff_std['dpm'].append(diff.std().item())
Thanks a mille for the PR @qsh-zh,
I've just added some tests and docs and I think now we can merge it!
Here is an example of how you can run the scheduler in diffusers:
from diffusers import StableDiffusionPipeline, DEISMultistepScheduler
import torch
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(prompt, generator=generator, num_inference_steps=25).images[0]