denoising-diffusion-pytorch

Technical question about sampling function

Open · dome272 opened this issue 2 years ago · 2 comments

Hi,

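I was wondering why every diffusion model implementation uses this particular sampling procedure. The DDPM paper gives the sampling algorithm (Algorithm 2) as:

1. $x_T \sim \mathcal{N}(0, \mathbf{I})$
2. for $t = T, \dots, 1$:
   - $z \sim \mathcal{N}(0, \mathbf{I})$ if $t > 1$, else $z = 0$
   - $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z$
3. return $x_0$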

However, it seems that no implementation follows it directly. Instead, they take a much more complicated route: first predict the noise, then compute x_0 from it, then the posterior mean and log-variance, and finally construct x_{t-1} from those.
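For reference, one step of that x_0 route can be sketched as follows. This is a minimal pure-Python sketch on a single scalar "pixel" (all the formulas are elementwise), assuming the DDPM linear β schedule from 1e-4 to 0.02 over 1000 steps and 0-indexed timesteps; the helper names `linear_schedule` and `x0_route_step` are hypothetical and are not this codebase's API.

```python
import math


def linear_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: assumed linear beta schedule and cumulative products of alphas."""
    betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
             for i in range(num_timesteps)]
    alphas_cumprod, acc = [], 1.0
    for beta in betas:
        acc *= 1.0 - beta
        alphas_cumprod.append(acc)
    return betas, alphas_cumprod


def x0_route_step(x_t, eps_pred, t, betas, alphas_cumprod, z=0.0, clip_x0=False):
    """Hypothetical helper: one reverse step eps -> x_0 -> posterior mean/log-variance -> x_{t-1}.

    t is 0-indexed (0 .. num_timesteps - 1); z is a standard-normal draw supplied by the caller.
    """
    beta = betas[t]
    alpha = 1.0 - beta
    a_bar = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[t - 1] if t > 0 else 1.0

    # 1) Estimate x_0 from x_t and the predicted noise.
    x0 = (x_t - math.sqrt(1.0 - a_bar) * eps_pred) / math.sqrt(a_bar)
    if clip_x0:
        x0 = max(-1.0, min(1.0, x0))  # optional clamp of the x_0 estimate only

    # 2) Mean of the posterior q(x_{t-1} | x_t, x_0).
    mean = (math.sqrt(a_bar_prev) * beta / (1.0 - a_bar)) * x0 \
        + (math.sqrt(alpha) * (1.0 - a_bar_prev) / (1.0 - a_bar)) * x_t

    # 3) Posterior variance (tilde beta_t) as a clipped log-variance; it is 0 at t == 0.
    log_var = math.log(max(beta * (1.0 - a_bar_prev) / (1.0 - a_bar), 1e-20))

    # 4) Sample x_{t-1}; the final step (t == 0) adds no noise.
    return mean if t == 0 else mean + math.exp(0.5 * log_var) * z
```

For example, `betas, a_bars = linear_schedule()` followed by `x0_route_step(0.3, -0.2, 999, betas, a_bars, z=0.5)` performs the first (largest-t) step for one value.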

I implemented the algorithm above on top of your codebase:

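```python
# Method on the codebase's diffusion class; uses torch, tqdm and the module's extract() helper.
@torch.no_grad()
def my_sample(self, n):
    # start from pure Gaussian noise x_T
    x = torch.randn((n, 3, self.image_size, self.image_size)).to(self.device)
    # iterate i = num_timesteps - 1, ..., 1
    for i in tqdm(reversed(range(1, self.num_timesteps)), position=0):
        t = (torch.ones(n) * i).long().to(self.device)
        predicted_noise = self.denoise_fn(x, t)
        beta = extract(self.betas, t, x.shape)
        alpha_hat = extract(self.alphas_cumprod, t, x.shape)
        alpha = 1. - beta
        # no noise is added on the last iteration
        if i > 1:
            noise = torch.randn_like(x)
        else:
            noise = torch.zeros_like(x)
        # sampling update with sigma_t = sqrt(beta_t)
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        # clamp the intermediate (still noisy) x_t to [-1, 1] at every step
        x = x.clamp(-1., 1.)
    # map from [-1, 1] to [0, 1]
    return x.add(1).mul(0.5)
```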

But the results are just gray images with a bit of shape and colour.

[Image: top row uses the normal sampling, as in your code; bottom row uses the sampling function above.]

Do you have any idea why this kind of sampling does not work?

dome272 · Jul 01 '22 19:07

This is weird. The route (predicted noise -> x_0 -> x_{t-1}) uses Eq. 9, and your implementation uses Eq. 10. I've verified that they are mathematically equivalent.
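The equivalence can be checked by substituting the x_0 estimate into the posterior mean. A short derivation in the DDPM notation $\alpha_t = 1 - \beta_t$, $\bar\alpha_t = \prod_{s \le t} \alpha_s$, assuming the x_0 estimate is not clamped:

```latex
\begin{aligned}
\hat{x}_0 &= \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta}{\sqrt{\bar\alpha_t}} \\
\tilde{\mu}_t &= \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,\hat{x}_0
  + \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t \\
&= \frac{\beta_t + \alpha_t(1-\bar\alpha_{t-1})}{\sqrt{\alpha_t}\,(1-\bar\alpha_t)}\,x_t
  - \frac{\beta_t}{\sqrt{\alpha_t}\,\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta
  \qquad \text{(using } \sqrt{\bar\alpha_{t-1}/\bar\alpha_t} = 1/\sqrt{\alpha_t}\text{)} \\
&= \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta\right)
  \qquad \text{(since } \beta_t + \alpha_t - \bar\alpha_t = 1 - \bar\alpha_t\text{)}
\end{aligned}
```

Because $\beta_t = 1 - \alpha_t$, this is exactly the mean in the sampling algorithm. The two routes can still differ in the noise scale, since the paper allows either $\sigma_t^2 = \beta_t$ or $\sigma_t^2 = \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t$, and in the result whenever $\hat{x}_0$ is clamped.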

I'd suggest checking the first few iterations (largest t) to see whether the two routines produce nearly identical numbers.
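Following that suggestion, here is a minimal pure-Python sketch that compares the two update means over the first few (largest-t) iterations, using random scalar stand-ins for x_t and the predicted noise. It assumes the DDPM linear schedule with T = 1000 and no clamping of x_0; the function names are hypothetical. It also measures how many values drawn from N(0, 1) a clamp to [-1, 1] would change: at large t, x_t is close to N(0, I), and the snippet in the question clamps x_t after every step, which is a plausible (but unconfirmed) cause of the gray samples.

```python
import math
import random


def linear_schedule(num_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    # Assumed DDPM linear schedule; returns betas and cumulative products of alphas.
    betas = [beta_start + (beta_end - beta_start) * i / (num_timesteps - 1)
             for i in range(num_timesteps)]
    alphas_cumprod, acc = [], 1.0
    for beta in betas:
        acc *= 1.0 - beta
        alphas_cumprod.append(acc)
    return betas, alphas_cumprod


def mean_direct(x_t, eps, t, betas, a_bars):
    # Mean from the sampling algorithm: 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1 - a_bar_t) * eps)
    return (x_t - betas[t] / math.sqrt(1.0 - a_bars[t]) * eps) / math.sqrt(1.0 - betas[t])


def mean_via_x0(x_t, eps, t, betas, a_bars):
    # eps -> unclamped x_0 estimate -> posterior mean of q(x_{t-1} | x_t, x_0)
    a_bar, a_bar_prev = a_bars[t], (a_bars[t - 1] if t > 0 else 1.0)
    x0 = (x_t - math.sqrt(1.0 - a_bar) * eps) / math.sqrt(a_bar)
    return (math.sqrt(a_bar_prev) * betas[t] * x0
            + math.sqrt(1.0 - betas[t]) * (1.0 - a_bar_prev) * x_t) / (1.0 - a_bar)


betas, a_bars = linear_schedule()
rng = random.Random(0)

# Compare the two routes on the first five iterations (largest t).
max_diff = 0.0
for t in range(len(betas) - 1, len(betas) - 6, -1):
    x_t, eps = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    diff = abs(mean_direct(x_t, eps, t, betas, a_bars) - mean_via_x0(x_t, eps, t, betas, a_bars))
    max_diff = max(max_diff, diff)
print(f"max |difference| over the first 5 steps: {max_diff:.1e}")

# Fraction of N(0, 1) values that clamp(-1, 1) would change.
samples = [rng.gauss(0.0, 1.0) for _ in range(100_000)]
frac_changed = sum(abs(s) > 1.0 for s in samples) / len(samples)
print(f"fraction of N(0, 1) values altered by clamp(-1, 1): {frac_changed:.3f}")
```

A difference near machine precision confirms the algebra. The clamp fraction should land close to 0.317, the probability that a standard normal value falls outside [-1, 1], which shows how strongly a per-step clamp reshapes the early, nearly Gaussian x_t.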

askerlee · Jul 05 '22 05:07

Hey @dome272, I am not sure why the code framework here does not work with the equation you referenced from the paper, and I have not had time to look into it in depth. However, someone else developed a really nice Google Colab notebook that implements the DDPM algorithm step by step, similar to this codebase, and it does use the equation from the algorithm you referenced above. I have tested their code myself and it gives good-looking outputs, so I think this could indicate that some detail in this implementation of the DDPM paper is not correct. The notebook is linked below; feel free to try it and experiment with it yourself.

https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/annotated_diffusion.ipynb

malekinho8 · Jul 20 '22 05:07