
BUG of Proper Timestep Alignment in Sampling Methods such as DDIM

Open · kwonminki opened this issue Apr 05 '23 · 4 comments


Upon careful observation, I have discovered a significant issue that may appear trivial but is, in fact, anything but trivial. It concerns the determination of timesteps when using sampling methods such as DDIM. The list of timesteps generated during DDIM sampling is as follows:

tensor([981, 961, 941, 921, 901, 881, 861, 841, 821, 801, 781, 761, 741, 721, 701, 681, 661, 641, 621, 601, 581, 561, 541, 521, 501, 481, 461, 441, 421, 401, 381, 361, 341, 321, 301, 281, 261, 241, 221, 201, 181, 161, 141, 121, 101, 81, 61, 41, 21, 1], device='cuda:1')

At first glance, one may wonder why the list starts at 981 rather than near T = 1000, the number of timesteps used during model training (i.e., a final timestep index of 999). Even though sampling may still produce satisfactory results despite the offset, it is noteworthy that the timesteps are not aligned with the ones the model was designed for. This misalignment can lead to significant issues, as I discovered while trying to solve DDIM inversion. Although the existing DDIM inversion method works to some extent, some images can still be severely damaged.
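For context, here is a minimal sketch (my own, not the actual scheduler code) of how a "leading" schedule of this kind, combined with steps_offset = 1 as commonly configured for Stable Diffusion, produces exactly the list above:

import numpy as np

num_train_timesteps = 1000   # T used during training
num_inference_steps = 50
steps_offset = 1             # offset commonly used for Stable Diffusion

# evenly spaced multiples of the step ratio, counted from 0 ("leading" spacing)
step_ratio = num_train_timesteps // num_inference_steps   # 20
timesteps = np.arange(num_inference_steps) * step_ratio    # 0, 20, ..., 980
timesteps = np.flip(timesteps) + steps_offset              # 981, 961, ..., 1

print(timesteps[:5])   # [981 961 941 921 901]

Because the spacing starts counting from 0 rather than from the final training timestep, the first sampling step can never reach 999 with this construction.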

To illustrate this issue, I attempted to invert an image of a cute puppy using the standard DDIM timestep list (I cropped the image for resolution).

[Image: x0_gen-DDIMinversion-Various-9-inv_-for_original] This is the result of inversion with the previous timesteps.
My cute puppy disappeared and something strange popped out instead. However, after setting the starting point of the timesteps to 999, the inversion result was excellent.
[Image: x0_gen-DDIMinversion-Various-9-inv_-for_custom]

Here is the list of timesteps.

tensor([999, 979, 958, 938, 918, 897, 877, 856, 836, 816, 795, 775, 755, 734, 714, 693, 673, 653, 632, 612, 592, 571, 551, 531, 510, 490, 469, 449, 429, 408, 388, 368, 347, 327, 307, 286, 266, 245, 225, 205, 184, 164, 144, 123, 103,  82,  62,  42,  21,   1], device='cuda:1')

And here is the code I used:

import numpy as np
import torch
from typing import Optional, Union


def custom_set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, inversion_flag: bool = False):
    """
    Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

    Args:
        num_inference_steps (int):
            The number of diffusion steps used when generating samples with a pre-trained model.
        inversion_flag (bool):
            If True, keep the timesteps in ascending order for DDIM inversion;
            otherwise reverse them for normal (denoising) sampling.
    """

    if num_inference_steps > self.config.num_train_timesteps:
        raise ValueError(
            f"num_inference_steps: {num_inference_steps} cannot be larger than self.config.num_train_timesteps:"
            f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
            f" maximal {self.config.num_train_timesteps} timesteps."
        )

    self.num_inference_steps = num_inference_steps
    # spread the steps evenly over [0, num_train_timesteps - 2]; together with
    # steps_offset (= 1 for Stable Diffusion) the schedule then spans 1 ... 999
    timesteps = np.linspace(0, 1, num_inference_steps) * (self.config.num_train_timesteps - 2)  # T = 999
    timesteps = timesteps + 1e-6  # nudge values off exact .5 boundaries before rounding
    timesteps = timesteps.round().astype(np.int64)
    # reverse the timesteps for normal sampling; keep ascending order for inversion
    if not inversion_flag:
        timesteps = np.flip(timesteps).copy()

    self.timesteps = torch.from_numpy(timesteps).to(device)
    self.timesteps += self.config.steps_offset
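For completeness, a hedged sketch of how one might attach and call this function on a diffusers-style DDIMScheduler (the MethodType binding and the model ID are my own illustration, not part of the original post; it assumes the scheduler config provides num_train_timesteps and steps_offset, as the diffusers DDIMScheduler does):

import types
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
scheduler.custom_set_timesteps = types.MethodType(custom_set_timesteps, scheduler)

# denoising schedule for text2img sampling: 999 -> ... -> 1
scheduler.custom_set_timesteps(50, device="cuda", inversion_flag=False)
print(scheduler.timesteps[:3])   # tensor([999, 979, 958], ...)

# ascending schedule for DDIM inversion: 1 -> ... -> 999
scheduler.custom_set_timesteps(50, device="cuda", inversion_flag=True)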

I will not delve into the reasons for the importance of the 999 timestep in inversion, as that would require a lengthy explanation. However, I will say that if the DDIM timesteps do not start at 999, not only the inversion result but also text2img sampling performance will suffer.

This is because text2img sampling starts from a pure Gaussian sample, which the model was trained to associate with the final timestep; beginning the denoising at t = 981 instead of t = 999 therefore creates a mismatch that can significantly affect the final result.
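As a rough way to quantify that mismatch (my own sketch, assuming the scaled-linear beta schedule used by Stable Diffusion's schedulers), one can compare the signal coefficient sqrt(alpha_bar_t) that the model expects at t = 981 versus t = 999:

import numpy as np

# scaled-linear beta schedule (beta_start = 0.00085, beta_end = 0.012), as in Stable Diffusion
betas = np.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

# sqrt(alpha_bar_t) is the coefficient of the clean image x0 inside x_t;
# pure Gaussian noise corresponds to this coefficient being (close to) zero
print(np.sqrt(alphas_cumprod[981]), np.sqrt(alphas_cumprod[999]))

The value at t = 999 is the smallest the model ever saw during training, so feeding pure noise into the model at t = 981 asks it to treat that noise as if it already contained some signal.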

kwonminki commented Apr 05 '23

Hi, I agree with you that these timesteps seem wrong, and I looked into it a little bit recently. But I want to note that the original implementation for the DDIM paper also determines them in the same way, see here (if I'm not misunderstanding their code, at least). The same is true for Hugging Face's diffusers library. So it is not a "bug" that was introduced here.

I am not aware of any research yet that compares these two approaches on a more objective basis. Nonetheless, I still think this is worth investigating. When I ran it, I did find a particularly large disconnect when using few timesteps (3-5, here) and a starting point much closer to 1000 than would otherwise be chosen. So right now the next step would be to launch a quantitative, FID-based study of this. I might do that in the future, if nobody else beats me to it.

jenuk commented Apr 06 '23

[Image: x0_gen-DDIMforward-for_a photograph of a hamburger] 50 steps with your prompt: a photograph of a hamburger, highly detailed, 4k

[Image: x0_gen-DDIMforward-for_a photograph of a hamburger_3_original] 3 steps with the original timesteps ([667, 334, 1])

[Image: x0_gen-DDIMforward-for_a photograph of a hamburger_3_mine] 3 steps with my timesteps ([999, 500, 1])

I also found a particularly large disconnect when running with few timesteps, and the results are better when I use the modified timesteps.
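For reference, a small sketch (my own) of how the two 3-step lists above come out, assuming T = 1000 and steps_offset = 1:

import numpy as np

T, n, offset = 1000, 3, 1

# original "leading" schedule
leading = np.flip(np.arange(n) * (T // n)) + offset   # [667 334 1]

# modified schedule that starts at 999 after the offset
modified = np.flip(np.round(np.linspace(0, 1, n) * (T - 2)).astype(np.int64)) + offset   # [999 500 1]

print(leading, modified)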

kwonminki commented Apr 06 '23

I want to share more experiments.

[Image: x0_gen-DDIMforward-for_a photograph of a hamburger_10step] DDIM, 10-step sampling with the previous timesteps.

[Image: x0_gen-DDIMforward-for_a photograph of a hamburger_10step_mine] DDIM, 10-step sampling with my modified timesteps.

I think this is a crucial observation.

kwonminki commented Apr 07 '23

Thanks for sharing this! Another concern I have is that the current DDIM sampler seems to stop at x_1 instead of x_0 at the end of sampling, since there is a 1-step offset in the DDIM time scheduler. Why did they choose to implement it this way?

Morefre commented Aug 24 '23