dpm-solver
Noisy results with "order == 1" (trying to replicate DDIM results)
Hi authors,
Thank you for the nice paper and clear code and documentation!!
I am trying DPM-Solver in my project for sampling acceleration. Previously, I could obtain reasonable results with DDIM (steps=10, 100, ...), but the results I get with DPM-Solver are pretty bad. Could you give some suggestions on the implementation?
Here are the details of my model: (1) Training: DDPM (L1 loss, predicting noise), T=1000, UNet with additional condition inputs, trained on audio data. (2) Beta schedule: sigmoid schedule (following https://arxiv.org/abs/2212.11972).
Code snippet that uses DPM-solver in my project:
self.betas = sigmoid_beta_schedule(timesteps=1000)
self.noise_schedule = NoiseScheduleVP(schedule='discrete', betas=self.betas)
self.model_fn = model_wrapper(
    self.net,
    self.noise_schedule,
    model_type="noise",  # or "x_start" or "v" or "score"
    model_kwargs={},
)
self.dpm_solver = DPM_Solver(self.model_fn, self.noise_schedule, algorithm_type="dpmsolver")
After the definition:
x_T = torch.randn(input.shape, device="cuda")
pred = self.dpm_solver.sample(
    x_T,
    condition,
    steps=20,
    order=1,
    skip_type="time_uniform",
    method="singlestep",
)
pred = unnormalize_to_zero_to_one(pred)
Thanks a lot!
Hi @WikiChao , does your code contain this line?: https://github.com/LuChengTHU/dpm-solver/blob/main/dpm_solver_pytorch.py#L105
If so, could you please print the first 5 and last 5 items of log_alphas?
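For reference, a minimal sketch of the check requested above. It assumes log_alphas is half the cumulative sum of log(1 - beta), which is the usual definition for a discrete VP schedule, and it uses a linear beta schedule as a stand-in because sigmoid_beta_schedule is the poster's own helper. Plain Python is used here so the snippet runs without torch; the variable names are illustrative, not the library's.

```python
import math
from itertools import accumulate

# Stand-in schedule (assumption for illustration): linear betas from 1e-4 to 0.02.
# Substitute the values returned by your own sigmoid_beta_schedule here.
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]

# log(alpha_t) = 0.5 * sum_{i <= t} log(1 - beta_i)
log_alphas = [0.5 * s for s in accumulate(math.log1p(-b) for b in betas)]

# Half log-SNR: lambda_t = log(alpha_t) - log(sigma_t), with sigma_t^2 = 1 - alpha_t^2
lambdas = [la - 0.5 * math.log(-math.expm1(2.0 * la)) for la in log_alphas]

print("log_alphas, first 5:", log_alphas[:5])
print("log_alphas, last 5: ", log_alphas[-5:])
print("half log-SNR, first 5:", lambdas[:5])
print("half log-SNR, last 5: ", lambdas[-5:])
```

Very negative log-SNR values at the end of the schedule may point to numerical trouble near t = T, which is the kind of thing this printout is meant to surface.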
Thanks for the prompt reply! The trick did help; it seems I was using an earlier version that was missing that line of code.
The results make sense now, but they are still worse than DDIM. I have tried different settings, e.g., method "multistep" or "singlestep", order = 2 or 3, and steps = 10 to 100, but cannot beat DDIM. Are there any tricks for choosing hyperparameters, for example, clipping the log-SNR at different values?
Thanks a lot!
Chao
Hi @WikiChao ,
"but they are still worse than DDIM": In fact, DPM-Solver with order=1 is exactly DDIM. You can try to reproduce your DDIM results by manually setting the timesteps in https://github.com/LuChengTHU/dpm-solver/blob/5c6ee9f1e6b60c8c54f955fbaab0a6717fc2b75b/dpm_solver_pytorch.py#L453 to match those in your DDIM code, and then check which part differs. I suspect you need to tune the timesteps carefully.
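The equivalence between order=1 and DDIM can be checked numerically on a single step. The sketch below is not taken from the repository: the function names and the scalar test values are made up for illustration, and it assumes a noise-prediction model with alpha_t^2 + sigma_t^2 = 1. It compares the deterministic DDIM update (eta = 0) with the first-order DPM-Solver update written in terms of the half log-SNR lambda = log(alpha / sigma).

```python
import math

def ddim_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    """Deterministic DDIM update (eta = 0) from time s to time t."""
    x0_pred = (x_s - sigma_s * eps) / alpha_s   # predicted clean sample
    return alpha_t * x0_pred + sigma_t * eps

def dpm_solver_1_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t):
    """First-order DPM-Solver update with noise prediction."""
    lambda_s = math.log(alpha_s / sigma_s)
    lambda_t = math.log(alpha_t / sigma_t)
    h = lambda_t - lambda_s                     # step size in half log-SNR
    return (alpha_t / alpha_s) * x_s - sigma_t * math.expm1(h) * eps

# Illustrative scalar values (assumptions, not from any trained model).
alpha_s, sigma_s = math.sqrt(0.3), math.sqrt(0.7)
alpha_t, sigma_t = math.sqrt(0.6), math.sqrt(0.4)
x_s, eps = 0.8, -1.3

a = ddim_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t)
b = dpm_solver_1_step(x_s, eps, alpha_s, sigma_s, alpha_t, sigma_t)
print("DDIM:", a, "DPM-Solver-1:", b, "abs diff:", abs(a - b))
```

Since the two updates agree for the same (s, t) pair and the same noise prediction, any remaining gap between a DDIM sampler and order=1 sampling most likely comes from the timestep grid or from how the model input is constructed, not from the update rule.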