Noisy results with "order == 1" (trying to replicate DDIM results)

Open WikiChao opened this issue 2 years ago • 3 comments

Hi authors,

Thank you for the nice paper and clear code and documentation!!

I am trying DPM-Solver in my project to accelerate sampling. Previously, I could obtain reasonable results with DDIM (steps=10, 100, ...), but the results I get with DPM-Solver are quite bad. Could you give some suggestions on the implementation?

Here are the details of my model: (1) Training: DDPM (L1 loss, predicting noise), T=1000, UNet with additional condition inputs, trained on audio data. (2) Beta schedule: sigmoid schedule (following https://arxiv.org/abs/2212.11972).

Code snippet that uses DPM-solver in my project:

    self.betas = sigmoid_beta_schedule(timesteps=1000)
    self.noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)
    self.model_fn = model_wrapper(
        self.net,
        self.noise_schedule,
        model_type="noise",  # or "x_start" or "v" or "score"
        model_kwargs={},
    )
    self.dpm_solver = DPM_Solver(self.model_fn, self.noise_schedule, algorithm_type="dpmsolver")

After the definition:

    x_T = torch.randn(input.shape, device="cuda")
    pred = self.dpm_solver.sample(
        x_T,
        condition,
        steps=20,
        order=1,
        skip_type="time_uniform",
        method="singlestep",
    )
    pred = unnormalize_to_zero_to_one(pred)

Thanks a lot!

WikiChao avatar Apr 24 '23 04:04 WikiChao

Hi @WikiChao, does your code contain this line? https://github.com/LuChengTHU/dpm-solver/blob/main/dpm_solver_pytorch.py#L105

If so, could you please print the first 5 and last 5 items of log_alphas?
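
For reference, a minimal sketch of that check. It assumes `log_alphas` is half the cumulative sum of `log(1 - beta_t)`, the usual definition for a discrete VP schedule. The linear `betas` below are only a hypothetical stand-in for the output of `sigmoid_beta_schedule`, and `log_alphas_from_betas` and `log_snr` are illustrative helpers, not part of the dpm-solver API:

```python
import math


def log_alphas_from_betas(betas):
    """Return log(alpha_t) = 0.5 * cumsum(log(1 - beta_t)) as a list."""
    out, running = [], 0.0
    for beta in betas:
        running += 0.5 * math.log(1.0 - beta)
        out.append(running)
    return out


def log_snr(log_alpha):
    """Half log-SNR lambda = log(alpha) - log(sigma), with sigma^2 = 1 - alpha^2."""
    log_sigma = 0.5 * math.log(1.0 - math.exp(2.0 * log_alpha))
    return log_alpha - log_sigma


# Hypothetical stand-in schedule (assumption): replace with your own betas.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

log_alphas = log_alphas_from_betas(betas)
print("first 5:", log_alphas[:5])
print("last 5: ", log_alphas[-5:])
print("log-SNR at t=T:", log_snr(log_alphas[-1]))
```

With a PyTorch tensor the same slices (`log_alphas[:5]`, `log_alphas[-5:]`) apply. Very negative values at the end of the array mean the log-SNR near t = T is extreme, which may point to numerical problems in the last part of the schedule.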

LuChengTHU avatar Apr 24 '23 10:04 LuChengTHU

Thanks for the prompt reply! That did help. It seems I was using an older version that is missing this line of code.

The results make sense now, but they are still worse than DDIM. I have tried different settings (method "multistep" or "singlestep", order 2 or 3, steps from 10 to 100) but cannot beat DDIM. Are there any tricks for choosing hyperparameters, for example clipping the log-SNR at different values?

Thanks a lot!

Chao

WikiChao avatar Apr 25 '23 19:04 WikiChao

Hi @WikiChao,

"but they are still worse than DDIM": In fact, order=1 is exactly DDIM. You can try to reproduce your DDIM results by manually setting the timesteps at https://github.com/LuChengTHU/dpm-solver/blob/5c6ee9f1e6b60c8c54f955fbaab0a6717fc2b75b/dpm_solver_pytorch.py#L453 to match the ones in your DDIM code, to check which part differs. My guess is that you need to tune the timesteps carefully.
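
To make the order=1 claim concrete, here is a small sketch that checks numerically that one first-order DPM-Solver step and one deterministic DDIM step (eta = 0) agree for a noise-prediction model. It assumes a VP schedule with sigma^2 = 1 - alpha^2; the function names and the scalar inputs are illustrative only and are not taken from the repository:

```python
import math


def sigma_from_alpha(alpha):
    """VP schedule assumption: sigma^2 = 1 - alpha^2."""
    return math.sqrt(1.0 - alpha * alpha)


def ddim_step(x_s, eps, alpha_s, alpha_t):
    """Deterministic DDIM (eta = 0): predict x0, then re-noise to level t."""
    sigma_s, sigma_t = sigma_from_alpha(alpha_s), sigma_from_alpha(alpha_t)
    x0_pred = (x_s - sigma_s * eps) / alpha_s
    return alpha_t * x0_pred + sigma_t * eps


def dpm_solver_1_step(x_s, eps, alpha_s, alpha_t):
    """First-order DPM-Solver update with h = lambda_t - lambda_s."""
    sigma_s, sigma_t = sigma_from_alpha(alpha_s), sigma_from_alpha(alpha_t)
    h = math.log(alpha_t / sigma_t) - math.log(alpha_s / sigma_s)
    return (alpha_t / alpha_s) * x_s - sigma_t * math.expm1(h) * eps


# Arbitrary illustrative values: a noisy sample, a predicted noise,
# and the signal levels at the current (s) and next (t) timestep.
x_s, eps = 0.7, -0.3
alpha_s, alpha_t = 0.5, 0.8

ddim_out = ddim_step(x_s, eps, alpha_s, alpha_t)
dpm_out = dpm_solver_1_step(x_s, eps, alpha_s, alpha_t)
print(ddim_out, dpm_out, abs(ddim_out - dpm_out))
```

Since the two update rules coincide, any remaining gap between the samplers should come from what is fed into them: the timestep grid or the noise schedule values at those timesteps.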

LuChengTHU avatar Apr 26 '23 15:04 LuChengTHU