
Question about `p_sample` :)

Open tryumanshow opened this issue 2 years ago • 2 comments

Hi! Thank you, as always, for the great code you provide!

Anyway, there are two points that I can't understand.

  1. Why do you use the posterior mean and variance in the reverse step (in the `p_sample` function)? I expected equation (11) of the original DDPM paper to be used (reproduced after this list for reference), but I don't see it in this code. Can you explain this for me? :)
    def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
        preds = self.model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
        b, *_, device = *x.shape, x.device # b: 4
        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, # isotropic Gaussian
                                                                          t = batched_times, 
                                                                          x_self_cond = x_self_cond, 
                                                                          clip_denoised = clip_denoised)  #  <------ This part!
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start
  2. Can you explain the intent of `(0.5 * model_log_variance).exp()` in `pred_img = model_mean + (0.5 * model_log_variance).exp() * noise` in `p_sample`?
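
For reference, equation (11) of the DDPM paper (Ho et al., 2020), as I read it, parameterizes the reverse-process mean in terms of the predicted noise:

    \mu_\theta(x_t, t)
        = \tilde{\mu}_t\!\left(x_t,\ \tfrac{1}{\sqrt{\bar{\alpha}_t}}\bigl(x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t)\bigr)\right)
        = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)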

The full code is below:

    @torch.no_grad()
    def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
        b, *_, device = *x.shape, x.device 
        batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, # isotropic Gaussian
                                                                          t = batched_times, 
                                                                          x_self_cond = x_self_cond, 
                                                                          clip_denoised = clip_denoised)
        noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise  # <------------ This part!
        return pred_img, x_start

tryumanshow · Dec 06 '22

I have the same question.

In my mind, `posterior_variance` is the one we need, as in `pred_img = x_start + posterior_variance * noise`.

lhaippp · Dec 16 '22

1: It is step 4 of Algorithm 2 (sampling), but written using the equivalence between equations 9 and 7: the code evaluates the second argument of mu_tilde_t( ... , HERE ) on the right-hand side of equation 9 (that is, the predicted x_0) and then expands mu_tilde_t using equation 7. I tested equation 11 directly and ran into stability issues; I think small numbers cause floating-point errors. This formulation also enables the clamping of the predicted image.
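
To make that equivalence concrete, here is a minimal numerical sketch (my own code, not from the repository; the linear beta schedule and variable names are just assumptions) checking that plugging the predicted x_0 into the posterior mean of equation 7, which is what `q_posterior` computes, reproduces the epsilon-parameterized mean of equation 11:

    import torch

    torch.manual_seed(0)

    T = 1000
    betas = torch.linspace(1e-4, 0.02, T)                # linear schedule as in DDPM
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, dim = 0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    t = 500                                              # arbitrary timestep
    x_t = torch.randn(4, 3, 8, 8)                        # stand-in noisy sample
    eps = torch.randn_like(x_t)                          # stand-in for eps_theta(x_t, t)

    a_t, ab_t, ab_prev, b_t = alphas[t], alphas_cumprod[t], alphas_cumprod_prev[t], betas[t]

    # predicted x_0 recovered from the noise prediction
    x0_hat = (x_t - (1. - ab_t).sqrt() * eps) / ab_t.sqrt()

    # posterior mean of q(x_{t-1} | x_t, x_0) with x_0 := x0_hat (equation 7)
    posterior_mean = (ab_prev.sqrt() * b_t / (1. - ab_t)) * x0_hat \
                   + (a_t.sqrt() * (1. - ab_prev) / (1. - ab_t)) * x_t

    # epsilon-parameterized mean (equation 11)
    eq11_mean = (x_t - (b_t / (1. - ab_t).sqrt()) * eps) / a_t.sqrt()

    print(torch.allclose(posterior_mean, eq11_mean, atol = 1e-5))   # True

Since both forms give the same mean, the repository's formulation mainly buys the ability to clamp the predicted x_0 before forming the mean, as noted above.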

2: There is a comment somewhere, "# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain". I can't remember whether it is from this repository, but this is probably again a numerical fix.

Note: sqrt(posterior_variance) == exp(0.5 * log(posterior_variance)); posterior_variance stores the square of sigma, so the factor in `p_sample` is just the standard deviation sigma.
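
A small sketch of both numerical points (again my own code, not from the repository; the clamp value 1e-20 is an assumption): the posterior variance is exactly 0 at t = 0, so its log has to be clipped, and exp(0.5 * log(var)) recovers sqrt(var), i.e. sigma:

    import torch

    betas = torch.linspace(1e-4, 0.02, 1000)
    alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
    alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])

    # beta_tilde_t = beta_t * (1 - alphabar_{t-1}) / (1 - alphabar_t), from equation 7
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    print(posterior_variance[0])                         # tensor(0.) -> log would be -inf
    posterior_log_variance_clipped = torch.log(posterior_variance.clamp(min = 1e-20))

    sigma = (0.5 * posterior_log_variance_clipped).exp()
    print(torch.allclose(sigma[1:], posterior_variance[1:].sqrt()))   # True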

robert-graf · Dec 20 '22