denoising-diffusion-pytorch
Question about `p_sample` :)
Hi! Thank you, as always, for the great code you provide!
Anyway, there are two points that I can't understand.
- Why do you use the posterior mean and variance in the reverse step (in the `p_sample` function)? I expected the code to use equation (11) of the original DDPM paper, but I don't think it appears in this code. Can you explain this for me? :) (I write out equation (11) after the code below for reference.)
def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
    preds = self.model_predictions(x, t, x_self_cond)
    x_start = preds.pred_x_start

    if clip_denoised:
        x_start.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
    b, *_, device = *x.shape, x.device  # b: 4
    batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x,  # isotropic Gaussian
                                                                      t = batched_times,
                                                                      x_self_cond = x_self_cond,
                                                                      clip_denoised = clip_denoised)  # <------ This part!
    noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start
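For reference, equation (11) of the DDPM paper (if I copy it correctly) is

$$
\mu_\theta(x_t, t) \;=\; \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right),
$$

while the posterior mean and variance from equation (7), which `q_posterior` seems to compute, are

$$
\tilde\mu_t(x_t, x_0) \;=\; \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\,x_0 \;+\; \frac{\sqrt{\alpha_t}\,(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}\,x_t,
\qquad
\tilde\beta_t \;=\; \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t .
$$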
- Can you explain the intent of `(0.5 * model_log_variance).exp()` in `pred_img = model_mean + (0.5 * model_log_variance).exp() * noise` in `p_sample` (called from `p_sample_loop`)?
The full code is as below:
@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None, clip_denoised = True):
    b, *_, device = *x.shape, x.device
    batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x,  # isotropic Gaussian
                                                                      t = batched_times,
                                                                      x_self_cond = x_self_cond,
                                                                      clip_denoised = clip_denoised)
    noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise  # <------------ This part!
    return pred_img, x_start
I have the same question. In my mind, `posterior_variance` is the one that we need, as in `pred_img = x_start + posterior_variance * noise`.
1: It is step 4 of the sampling algorithm (Algorithm 2 in the DDPM paper), but implemented through the equivalence between equations (9) and (7): in equation (9), the right-hand argument of mu_tilde_t( ..., HERE) is x_0 reconstructed from x_t and the predicted noise, and mu_tilde_t itself is then evaluated with equation (7). I tested equation (11) directly and came across stability issues; I think the small numbers cause some floating point errors. Predicting x_0 first also enables the clamping of the image.
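Here is a minimal standalone sketch of that equivalence (a toy schedule and random tensors, not the repository's code; the names only mirror the buffers in the repo): reconstructing x_0 from the predicted noise and plugging it into the equation (7) posterior mean gives the same mean as equation (11).

```python
import torch

# toy linear beta schedule with T = 10 steps (illustrative values only)
T = 10
betas = torch.linspace(1e-4, 0.02, T, dtype = torch.float64)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim = 0)

t = 5                                          # some intermediate timestep
x_t = torch.randn(3, dtype = torch.float64)    # fake noisy sample
eps = torch.randn(3, dtype = torch.float64)    # stand-in for the predicted noise eps_theta(x_t, t)

# route A: equation (11), the mean written directly in terms of the predicted noise
mu_eq11 = (x_t - betas[t] / (1. - alphas_cumprod[t]).sqrt() * eps) / alphas[t].sqrt()

# route B: what p_sample does -- reconstruct x_0, then take the equation (7) posterior mean
x_0 = (x_t - (1. - alphas_cumprod[t]).sqrt() * eps) / alphas_cumprod[t].sqrt()
coef1 = betas[t] * alphas_cumprod[t - 1].sqrt() / (1. - alphas_cumprod[t])
coef2 = (1. - alphas_cumprod[t - 1]) * alphas[t].sqrt() / (1. - alphas_cumprod[t])
mu_posterior = coef1 * x_0 + coef2 * x_t

print(torch.allclose(mu_eq11, mu_posterior))   # True, up to floating point error
```

The practical difference is only where you can intervene: route B exposes the predicted x_0, which is what lets `clip_denoised` clamp it to [-1, 1] before the mean is formed.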
2: There is a comment somewhere along the lines of "# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain". I can't remember whether that comment is from this repository, but it is probably again a numerical fix.
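A small standalone check (again a toy schedule, not the repository's exact code or clamp value) shows why the raw log would blow up at the start of the chain:

```python
import torch

# toy linear beta schedule; values are illustrative only
betas = torch.linspace(1e-4, 0.02, 10, dtype = torch.float64)
alphas_cumprod = torch.cumprod(1. - betas, dim = 0)
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype = torch.float64), alphas_cumprod[:-1]])

# posterior variance: beta_tilde_t = (1 - abar_{t-1}) / (1 - abar_t) * beta_t
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

print(posterior_variance[0])                                  # 0. at the start of the chain
print(torch.log(posterior_variance)[0])                       # -inf, which would poison sampling
print(torch.log(posterior_variance.clamp(min = 1e-20))[0])    # finite once clipped
```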
Note: `sqrt(posterior_variance) == exp(0.5 * log(posterior_variance))`. `posterior_variance` stores the square of σ, so `(0.5 * model_log_variance).exp()` is just the standard deviation σ that multiplies the noise.
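A quick sanity check of that identity (arbitrary example variances, nothing from the repo):

```python
import torch

var = torch.tensor([1e-20, 1e-4, 0.02], dtype = torch.float64)  # example variances, i.e. sigma squared
print(torch.allclose(var.sqrt(), (0.5 * var.log()).exp()))       # True: exp(0.5 * log(var)) == sqrt(var) == sigma
```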