
feat: add log-rho DEIS multistep scheduler

Open qsh-zh opened this issue 1 year ago • 7 comments

Add DEIS multistep methods

  • I fit the extrapolation polynomial in log-rho space instead of the original linear t space used in the DEIS paper. With this modification, the exponential multistep update has closed-form coefficients instead of relying on a numerical integrator (see the sketch after this list).
  • I can add other variants if we have vmap and grad support in torch.
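
For concreteness, here is a rough sketch of the update this describes, written in the half-log-SNR notation $\lambda_t = \log(\alpha_t/\sigma_t)$ from the DPM-Solver paper (so log-rho coincides with $\lambda$ up to sign); the exact coefficient expressions in the PR may differ:

$$x_{t_i} \;=\; \frac{\alpha_{t_i}}{\alpha_{t_{i-1}}}\, x_{t_{i-1}} \;-\; \alpha_{t_i} \sum_{j=0}^{k-1} C_{ij}\, \epsilon_\theta\big(x_{t_{i-1-j}},\, t_{i-1-j}\big), \qquad C_{ij} \;=\; \int_{\lambda_{t_{i-1}}}^{\lambda_{t_i}} e^{-\lambda}\, \ell_j(\lambda)\, d\lambda,$$

where $\ell_j$ is the Lagrange basis polynomial through the $\lambda$ values of the stored evaluations. Because $e^{-\lambda}$ times a polynomial integrates analytically, the $C_{ij}$ are closed-form; writing the same integral over $t$ instead (as in t-AB-DEIS) brings in the schedule-dependent factor $e^{-\lambda(\tau)}\lambda'(\tau)$ and requires numerical quadrature.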

#280

Compared with DPM-Solver on stable diffusion. Note that the DEIS in the table is t-AB-DEIS rather than log-rho-AB-DEIS: the former fits the polynomial in t space, while the latter fits it in log-rho space.

https://arxiv.org/pdf/2211.01095.pdf

[image: comparison table]

qsh-zh avatar Nov 26 '22 06:11 qsh-zh


Hi @qsh-zh ,

The comparison does not seem fair, because I used your t-AB-DEIS in my paper :(

Is the log-rho-DEIS better than t-AB-DEIS or DPM-Solver? (On CIFAR-10 it seems that log-rho-DEIS is worse than DPM-Solver?)

LuChengTHU avatar Nov 29 '22 07:11 LuChengTHU

@LuChengTHU Thanks for the discussion. You are right; the numbers in the table are for t-AB-DEIS.

Actually, inspired by your work and the variational diffusion models paper, I find that fitting a polynomial in log-rho (lambda in your work) is more reasonable than fitting in $\rho$ for image diffusion models. Empirically, rho-AB-DEIS is worse than DPM-Solver. After switching to log-rho space, log-rho-AB-DEIS and DPM-Solver++ multistep have similar performance on stable diffusion in my preliminary experiments. I plan to run more experiments to test them.
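
(For reference, and assuming $\rho_t$ here denotes the noise-to-signal ratio $\sigma_t/\alpha_t$, the two parameterizations differ only by a sign: $\log \rho_t = \log(\sigma_t/\alpha_t) = -\lambda_t$, so a polynomial in $\log\rho$ is exactly a polynomial in $\lambda$, while a polynomial in $\rho$ itself is not.)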

I initially wanted to merge the t-AB method, but I am not sure whether the diffusers library supports JAX inside its PyTorch code, since t-AB requires a numerical integrator and the vmap and grad tools would make the implementation much easier.

I have modified the PR message to clarify this point, and I look forward to merging other variants into the codebase and having more comparisons between DEIS and DPM-Solvers.

BTW, I would like to have more discussions in person if you are at NeurIPS. Let me know if you are available (fingers crossed).

Best, Qinsheng

qsh-zh avatar Nov 30 '22 07:11 qsh-zh

Hi @qsh-zh ,

Thank you for modifying the message. I'm sorry, the comments I left above read as too rude, and I apologize :( . Actually, I'm a big fan of many of your papers (DEIS, gDDIM, PIS, DiffFlow). It's my pleasure to discuss with you!

It's very interesting to know that log-rho-AB-DEIS and DPM-Solver++ share many similarities! I'm curious whether such findings can bring new insights to the community or even inspire some new, faster solvers :)

In addition, a small suggestion: maybe you can first implement the JAX version of DEIS? Many users also use the JAX version of Diffusers, and I believe it is worth merging your excellent t-AB version into Diffusers :)

And it is a pity that I cannot go to NeurIPS because of the Covid-19 policy in China... But I'm eager to discuss with excellent ML researchers like you. I've sent an email to you and I'm sincerely looking forward to your reply!

Best, Cheng

LuChengTHU avatar Nov 30 '22 15:11 LuChengTHU

Hey @qsh-zh,

Thanks a lot for the PR! Does it already work with stable diffusion? :-)

patrickvonplaten avatar Dec 01 '22 17:12 patrickvonplaten

@LuChengTHU We can chat more about research and collaborations.

@patrickvonplaten I tested it locally, and it works with stable diffusion v1.x. The codebase is heavily based on @LuChengTHU's DPM-Solver++ multistep implementation, and the main difference lies in the multistep coefficients. I still need to validate it with v2.0 and will report some qualitative results in this thread.

qsh-zh avatar Dec 02 '22 05:12 qsh-zh

Hey @qsh-zh,

Anything we can help with, or should we merge as-is for now?

patrickvonplaten avatar Dec 20 '22 01:12 patrickvonplaten

@patrickvonplaten I tested it locally. It works with the stable diffusion pipeline.

Comparison against DPM++ multistep

setup

Following the setup introduced in the DPM++ paper, I generate latent codes with DPM++ and DEIS using the first 200 prompts from COCO and the same random noise. The "ground truth" latent codes for those random noises are obtained with 999-step DDIM (the paper originally uses 1000 steps, but I ran into a numerical issue in the diffusers codebase; I expect the difference between 999-step and 1000-step DDIM to be small). Finally, I compare the squared L2 distance between the samples generated by each scheduler and the "ground truth" latent codes.

performance (squared L2 distance to the 999-step DDIM latents, lower is better; NFE = number of function evaluations)

NFE      10      15      20      30     50
DEIS     25.93   21.25   15.39   8.77   3.89
DPM++    25.37   22.49   15.95   8.80   3.88

Some comments

  • Both DEIS and DPM / DPM++ are based on the variation-of-constants formula; they differ in how they approximate the nonlinear term (the neural network) in the diffusion ODE. A numerical sketch follows this list.
    • DEIS: multistep method; extrapolates a polynomial fitted to previous network evaluations (epsilon-based).
    • DPM: single-step method; constructs a polynomial from multiple network evaluations within a single step (epsilon-based).
    • DPM++ multistep: constructs a polynomial from previous network evaluations; the higher-order derivatives of the polynomial are obtained by finite differences of those evaluations (x0-based).
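
To make the shared structure concrete, here is a minimal numerical sketch (my own illustration, not the diffusers implementation; the function name and the toy values are placeholders) of how the exponential multistep weights arise from integrating Lagrange basis polynomials in lambda = log(alpha/sigma) space. log-rho-AB-DEIS evaluates these integrals in closed form rather than by quadrature:

# Minimal numerical sketch (illustration only, not the diffusers implementation).
# In lambda = log(alpha/sigma) space, the exact solution of the epsilon-prediction
# diffusion ODE from lambda_s to lambda_t is
#   x_t = (alpha_t / alpha_s) * x_s - alpha_t * \int_{lambda_s}^{lambda_t} e^{-lambda} eps(x_lambda) dlambda,
# and the multistep schemes replace eps(x_lambda) by the Lagrange polynomial through
# the last k stored epsilon evaluations, turning the integral into a weighted sum.
import numpy as np

def exponential_ab_coefficients(lambdas_prev, lambda_s, lambda_t, n_quad=2048):
    """Weights C_j such that the integral above is approximated by
    sum_j C_j * eps_j, with eps_j the stored evaluations at lambdas_prev.
    Computed here by brute-force trapezoidal quadrature; log-rho-AB-DEIS
    obtains the same integrals in closed form."""
    grid = np.linspace(lambda_s, lambda_t, n_quad)
    coeffs = []
    for j, lam_j in enumerate(lambdas_prev):
        # Lagrange basis polynomial l_j(lambda) through the previous lambda values.
        basis = np.ones_like(grid)
        for m, lam_m in enumerate(lambdas_prev):
            if m != j:
                basis = basis * (grid - lam_m) / (lam_j - lam_m)
        integrand = np.exp(-grid) * basis
        # trapezoidal rule over the quadrature grid
        coeffs.append(float(np.sum(0.5 * (integrand[:-1] + integrand[1:]) * np.diff(grid))))
    return coeffs

# Example: weights for a second-order update that reuses the two latest evaluations.
print(exponential_ab_coefficients(lambdas_prev=[0.3, 0.5], lambda_s=0.5, lambda_t=0.8))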

code to reproduce comparison

# !pip install jammy
# !modify pipeline_stable_diffusion.py to return the latent code instead of the decoded image
import torch
from src.diffusers.schedulers.scheduling_deis_multistep import DEISMultistepScheduler
from src.diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from src.diffusers.schedulers.scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
import random
import numpy as np
import jammy.io as jio

# DPM-Solver++ multistep baseline
dpm = DPMSolverMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    solver_order=2,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="dpmsolver++",
    solver_type="midpoint",
    lower_order_final=True,
)

# DEIS scheduler under test (second-order AB)
deis = DEISMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    ab_order=2,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="deis",
    lower_order_final=True,
)

# First-order DEIS (reduces to DDIM), used for the 999-step "ground truth"
ddim = DEISMultistepScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    ab_order=1,
    predict_epsilon=True,
    thresholding=False,
    algorithm_type="deis",
    lower_order_final=True,
)

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16")
pipe = pipe.to(device)


import urllib.request, json
with urllib.request.urlopen("https://raw.githubusercontent.com/tylin/coco-caption/master/results/captions_val2014_fakecap_results.json") as url:
    coco_text = json.load(url)
coco_prompt = [item['caption'] for item in coco_text]

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

def generate_fn(seed, prompts, scheduler, num_inference_steps, batch_size=20):
    # Generate latents for all prompts with the given scheduler, re-seeding
    # before every batch so that all schedulers start from identical noise.
    pipe.scheduler = scheduler

    len_prompts = len(prompts)
    groups = num_to_groups(len_prompts, batch_size)
    imgs = []
    counter = 0
    for cur_batch_size in groups:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        _prompts = prompts[counter:counter+cur_batch_size]
        counter = counter+cur_batch_size
        _imgs = pipe(_prompts, num_inference_steps=num_inference_steps, output_type = "latent")
        imgs.append(_imgs)
    return torch.cat(imgs)

num_text = 200
# generate ddim gt
ddim_img = generate_fn(1, coco_prompt[:num_text], ddim, 999, 20)
jio.dump(f'coco{num_text}_ddim999.pth', ddim_img)

# deis and dpm
def deis_dpm_fn(prompts, num_inference_steps, batch_size=20):
    deis_img = generate_fn(1, prompts, deis, num_inference_steps, batch_size)
    dpm_img = generate_fn(1, prompts, dpm, num_inference_steps, batch_size)

    return deis_img, dpm_img
for step in [10, 15, 20, 30, 50]:
    deis_img, dpm_img = deis_dpm_fn(coco_prompt[:num_text], step)
    jio.dump(
        f'coco{num_text}_deis{step}.pth', deis_img
    )

    jio.dump(
        f'coco{num_text}_dpm{step}.pth', dpm_img
    )

# compare both schedulers against the DDIM "ground truth" (squared L2 distance)
ddim_gt = torch.tensor(
    jio.load(f'coco{num_text}_ddim999.pth'), dtype=torch.float32
)
diff_mean = {
    'dpm' : [],
    'deis' : [],
}
diff_std = {
    'dpm' : [],
    'deis' : [],
}

for step in [10, 15, 20, 30, 50]:
    deis_img = jio.load(f'coco{num_text}_deis{step}.pth')
    diff = torch.abs(
            ddim_gt - deis_img
        )**2
    diff_mean['deis'].append(
        diff.mean().item()
    )
    diff_std['deis'].append(
        diff.std().item()
    )

    dpm_img = jio.load(f'coco{num_text}_dpm{step}.pth')
    diff = torch.abs(
            ddim_gt - dpm_img
        )**2
    diff_mean['dpm'].append(
        diff.mean().item()
    )
    diff_std['dpm'].append(
        diff.std().item()
    )

qsh-zh avatar Jan 02 '23 04:01 qsh-zh

Thanks a mille for the PR @qsh-zh,

I've just added some tests and docs and I think now we can merge it!

Here is an example of how you can run the scheduler in diffusers:

from diffusers import StableDiffusionPipeline, DEISMultistepScheduler
import torch

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DEISMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cuda").manual_seed(0)
image = pipe(prompt, generator=generator, num_inference_steps=25).images[0]

patrickvonplaten avatar Jan 04 '23 22:01 patrickvonplaten