improved-diffusion
p_mean_variance mean calculation
I was looking through the code to see how the paper was implemented, but I ran into an issue with the part of the paper that defines the KL loss between two Gaussians:

Specifically, the loss at time t-1 is the KL divergence between the predicted Gaussian and the real Gaussian at time t-1:

L_{t-1} := D_{KL}\big( q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t) \big)

The predicted Gaussian is defined as follows (eq 3), with its mean parameterized from the predicted noise \epsilon_\theta (eq 13):

p_\theta(x_{t-1} \mid x_t) := \mathcal{N}\big( x_{t-1};\, \mu_\theta(x_t, t),\, \Sigma_\theta(x_t, t) \big), \qquad \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t) \Big)

And the real Gaussian is defined as follows (eqs 10-12):

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big( x_{t-1};\, \tilde\mu_t(x_t, x_0),\, \tilde\beta_t I \big), \qquad \tilde\mu_t(x_t, x_0) := \frac{\sqrt{\bar\alpha_{t-1}}\, \beta_t}{1 - \bar\alpha_t}\, x_0 + \frac{\sqrt{\alpha_t}\, (1 - \bar\alpha_{t-1})}{1 - \bar\alpha_t}\, x_t, \qquad \tilde\beta_t := \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\, \beta_t
The formulation of the loss function makes sense to me, but when I look at the code, it looks like the authors are having the model predict mu_tilde (eq 11) as opposed to mu (eq 13). I'm looking at the following function in the code: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L232
In this function, the mean is calculated from epsilon by first computing a prediction for x_0, then computing the posterior mean at time t-1 from it.
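Concretely, the flow seems to be the following (a minimal sketch in my own notation of what I think the function does; the helper names and buffer layout here are made up, not the repo's exact API):

```python
import torch

def predicted_mean(eps_model, x_t, t, alphas, alphas_cumprod, alphas_cumprod_prev, betas):
    """Sketch: mean of p(x_{t-1} | x_t), computed by going through x_0 first."""
    eps = eps_model(x_t, t)                 # network predicts the noise epsilon
    a_bar = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod_prev[t]
    # step 1: recover an estimate of x_0 from x_t and the predicted noise
    pred_xstart = (x_t - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()
    # step 2: plug the x_0 estimate into the posterior mean, eq 11
    coef1 = a_bar_prev.sqrt() * betas[t] / (1.0 - a_bar)
    coef2 = alphas[t].sqrt() * (1.0 - a_bar_prev) / (1.0 - a_bar)
    return coef1 * pred_xstart + coef2 * x_t
```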

To predict x_0, the following function is used: https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/gaussian_diffusion.py#L328
But this function looks to be the formulation for the mean function (eq 13).
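For reference, if I'm reading that function right, it computes

\hat{x}_0 = \sqrt{1/\bar\alpha_t}\, x_t - \sqrt{1/\bar\alpha_t - 1}\, \epsilon_\theta(x_t, t),

which has the same linear combination of x_t and \epsilon_\theta as eq 13, hence my confusion.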
I have a couple of questions regarding the implementation:
- Why is the mean function (eq 11) of the real Gaussian distribution (eq 12) used to compute the mean of the predicted Gaussian distribution (eq 3), when the paper formulates the predicted mean via eq 13?
- Why is x_0/x_start being calculated directly from eq 13, the predicted mean?
Thanks for the help!
I think they follow the original DDPM implementation. By predicting x_0 first, they can apply value clipping to x_0, a trick not mentioned in the paper that improves sampling quality. Then, from the clipped x_0, they do posterior sampling to get x_{t-1}.
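A minimal sketch of one sampling step with that trick, assuming data scaled to [-1, 1] (illustrative code with made-up names, not the repo's exact implementation):

```python
import torch

def p_sample_with_clipping(eps_model, x_t, t, alphas, alphas_cumprod, alphas_cumprod_prev, betas):
    """One ancestral step x_t -> x_{t-1} that routes through a clipped x_0 estimate."""
    eps = eps_model(x_t, t)
    a_bar = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod_prev[t]
    # estimate x_0 and clip it to the data range; going straight from x_t to
    # x_{t-1} with eq 13 would give no place to apply this clipping
    pred_xstart = ((x_t - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()).clamp(-1, 1)
    # posterior mean (eq 11) and variance (eq 10), using the clipped x_0
    mean = (a_bar_prev.sqrt() * betas[t] / (1.0 - a_bar)) * pred_xstart + \
           (alphas[t].sqrt() * (1.0 - a_bar_prev) / (1.0 - a_bar)) * x_t
    var = (1.0 - a_bar_prev) / (1.0 - a_bar) * betas[t]
    noise = torch.randn_like(x_t) if t > 0 else torch.zeros_like(x_t)
    return mean + var.sqrt() * noise
```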
Without value clipping, the two are the same: you can either go directly from x_t to x_{t-1} using eq 13, or go through x_0 first and then to x_{t-1}. See the discussion in https://github.com/hojonathanho/diffusion/issues/5
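To spell out the algebra behind that (my own substitution, using \bar\alpha_t = \alpha_t \bar\alpha_{t-1}): plugging

\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\, \epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}

into eq 11 collapses to

\tilde\mu_t(x_t, \hat{x}_0) = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{\beta_t}{\sqrt{1 - \bar\alpha_t}}\, \epsilon_\theta(x_t, t) \Big),

which is exactly eq 13, so the two routes agree when no clipping is applied.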
Oh yeah, that makes sense! As I've learned more about diffusion models, it also looks like predicting x_0 produces better results, since it lets one skip steps as in DDIM.
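For anyone reading later: the deterministic DDIM update (eta = 0) is what makes the step-skipping possible, and it also routes through the x_0 estimate. A sketch under the same assumptions as above (illustrative, not the repo's code):

```python
import torch

def ddim_step(eps_model, x_t, t, t_prev, alphas_cumprod):
    """Deterministic DDIM update (eta = 0); t_prev may skip many timesteps."""
    eps = eps_model(x_t, t)
    a_bar = alphas_cumprod[t]
    a_bar_prev = alphas_cumprod[t_prev]
    # same x_0 estimate as before, then re-noise it directly to timestep t_prev
    pred_xstart = ((x_t - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()).clamp(-1, 1)
    return a_bar_prev.sqrt() * pred_xstart + (1.0 - a_bar_prev).sqrt() * eps
```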