denoising-diffusion-pytorch
DDIM sampler for Continuous time Gaussian diffusion?
Hello,
Is it possible to use the DDIM sampler when time is continuous (continuous_time_gaussian_diffusion.py)? Could you please provide a simple code example?
I tried to implement it like this (scheduling_ddim.py):
```python
def _get_variance(self, timestep, prev_timestep):
    alpha_prod_t = self.alphas_cumprod[timestep]
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    return variance
```
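In continuous time there is no `alphas_cumprod` table to index into, but the same quantity can be computed directly from the log-SNR. Here is a hedged sketch (my own helper names, not from the repo) assuming the convention used by `log_snr_to_alpha_sigma`, where ᾱ(t) = sigmoid(log_snr(t)):

```python
import math

def log_snr_to_alpha_sigma_scalar(log_snr):
    # ᾱ(t) = sigmoid(log_snr), so alpha = sqrt(ᾱ) and sigma = sqrt(1 - ᾱ)
    alpha_sq = 1.0 / (1.0 + math.exp(-log_snr))
    return math.sqrt(alpha_sq), math.sqrt(1.0 - alpha_sq)

def get_variance_continuous(log_snr_t, log_snr_t_prev):
    # same quantity as _get_variance above, with alpha_prod = alpha ** 2
    alpha_t, _ = log_snr_to_alpha_sigma_scalar(log_snr_t)
    alpha_prev, _ = log_snr_to_alpha_sigma_scalar(log_snr_t_prev)
    alpha_prod_t = alpha_t ** 2
    alpha_prod_t_prev = alpha_prev ** 2
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    return (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
```

Since `t_prev < t` means less noise, `log_snr_t_prev > log_snr_t`, and the variance comes out in (0, 1).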
```python
# DDIM sampler
# See formulas (12) and (16) of the DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Notation (<variable name> -> <name in paper>):
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
def ddim_sample(
    self,
    unet,
    image,
    t,
    class_cond,
    cond_scale = 1.,
    eta = 0.0,
    variance_noise = None,
    t_next = None,
):
    pred = unet.call_with_cond_scale(
        image,
        self.log_snr(t),
        class_cond,
        cond_scale = cond_scale,
    )

    # 1. compute alphas, sigmas (alpha corresponds to sqrt(ᾱ_t) in the paper)
    log_snr = self.log_snr(t)
    log_snr_next = self.log_snr(t_next)
    log_snr, log_snr_next = map(partial(right_pad_dims_to, image), (log_snr, log_snr_next))
    alpha, sigma = log_snr_to_alpha_sigma(log_snr)
    alpha_next, sigma_next = log_snr_to_alpha_sigma(log_snr_next)

    # 2. compute predicted original sample from the prediction, also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if unet.pred_objective == 'noise':
        pred_original_sample = self.predict_start_from_noise(image, t = t, noise = pred)
    elif unet.pred_objective == 'x_start':
        pred_original_sample = pred
    elif unet.pred_objective == 'v':
        pred_original_sample = self.predict_start_from_v(image, t = t, v = pred)
    else:
        raise ValueError(f'unknown objective {unet.pred_objective}')

    # 3. clip or threshold "predicted x_0"
    pred_original_sample = tf.clip_by_value(pred_original_sample, -1.0, 1.0)

    # 4. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = η * sqrt((1 − ᾱ_t−1)/(1 − ᾱ_t)) * sqrt(1 − ᾱ_t/ᾱ_t−1), with ᾱ_t = alpha**2
    std_dev_t = eta * (sigma_next / sigma) * tf.sqrt(1. - alpha**2 / alpha_next**2)

    # 5. the pred_epsilon is always re-derived from the clipped x_0 in GLIDE
    pred_epsilon = (image - alpha * pred_original_sample) / sigma

    # 6. compute "direction pointing to x_t" of formula (12)
    pred_sample_direction = (1 - alpha_next**2 - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t-1 without "random noise" of formula (12)
    prev_sample = alpha_next * pred_original_sample + pred_sample_direction

    if eta > 0:
        if variance_noise is None:
            variance_noise = tf.random.normal(tf.shape(pred), dtype = pred.dtype)
        prev_sample = prev_sample + std_dev_t * variance_noise

    return prev_sample, pred_original_sample
```
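As a sanity check of the update rule, here is a minimal NumPy sketch of a single DDIM step in the same alpha = sqrt(ᾱ) convention (names are mine, not from the repo). With eta = 0 and a perfect noise prediction, one step should map alpha * x0 + sigma * eps exactly to alpha_next * x0 + sigma_next * eps, since sqrt(1 - alpha_next**2) = sigma_next:

```python
import numpy as np

def ddim_step(x_t, pred_noise, alpha, sigma, alpha_next, sigma_next, eta=0.0):
    # single DDIM update, formula (12); deterministic when eta == 0
    x0 = (x_t - sigma * pred_noise) / alpha            # predicted x_0
    x0 = np.clip(x0, -1.0, 1.0)                        # clip as in GLIDE
    eps = (x_t - alpha * x0) / sigma                   # re-derived epsilon
    # formula (16) rewritten with alpha = sqrt(alpha_prod)
    std = eta * (sigma_next / sigma) * np.sqrt(1.0 - alpha**2 / alpha_next**2)
    direction = np.sqrt(1.0 - alpha_next**2 - std**2) * eps
    return alpha_next * x0 + direction + std * np.random.randn(*x_t.shape)
```

Usage: with `x0 = 0.5`, `eps = 0.3`, `alpha, sigma = 0.8, 0.6`, and `alpha_next = 0.9`, the output of `ddim_step(alpha * x0 + sigma * eps, eps, alpha, sigma, alpha_next, np.sqrt(0.19), eta=0.0)` matches `alpha_next * x0 + np.sqrt(0.19) * eps`.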
The only remaining issue with this implementation is how to calculate `self.alphas_cumprod`. Or is this not a good way of thinking about it?
Thanks.
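One way to think about it: in the continuous-time formulation the cumulative product is replaced by a closed-form schedule, so ᾱ(t) = sigmoid(log_snr(t)) for t in [0, 1]. A small sketch, assuming the linear-beta log-SNR schedule from continuous_time_gaussian_diffusion.py (the `1e-4 + 10 * t**2` constants are that repo's defaults and may differ in other implementations):

```python
import math

def beta_linear_log_snr(t):
    # log-SNR for the linear-beta schedule, t in [0, 1]
    return -math.log(math.expm1(1e-4 + 10 * t * t))

def alphas_cumprod(t):
    # ᾱ(t) = sigmoid(log_snr(t)): the continuous analog of the lookup table
    log_snr = beta_linear_log_snr(t)
    return 1.0 / (1.0 + math.exp(-log_snr))
```

This is monotonically decreasing from ᾱ(0) ≈ 1 (almost no noise) to ᾱ(1) ≈ 0 (pure noise), so indexing a precomputed table is simply replaced by evaluating the function at the sampled times.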
@lucidrains Hi, is the implementation above equivalent to the DDIM sampler for continuous time?
Hi @markub3327, did you ever manage to implement DDIM for continuous time? I am also interested in this.
@lucidrains Yes, I did. Why do you ask? Do you know something new?
@markub3327 It would be great if you could share the implementation, because I also tried something similar but was unable to make it work. Or did what you posted here work well for you?
I also tried squaring `alpha` and `alpha_next`, but the results are still poor...
@danbochman Hi, I don't have any implementation other than the one above, and I get better results with discrete timesteps using this implementation. Here is a DDIM sampler with continuous time that may work well: https://keras.io/examples/generative/ddim/