mcvd-pytorch
Question about DDPM and DDIM sampling.
Hi, thanks for sharing your excellent work!
I just walked through the code base and noticed that during sampling you use the timestep t from 0 to 999 (see here). I thought the reverse pass should start from 999 and go down to 0, so I'm a little confused about this.
Another question: what does the `denoise` option mean for the last sampling step? Please check here.
Both questions apply to either the DDPM or the DDIM sampler. I'd really appreciate your explanation.
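For context (my own assumption about the common pattern, not a claim about mcvd-pytorch's exact implementation): in many DDPM/score-based samplers a `denoise` flag only changes the very last reverse step, returning the predicted mean instead of adding one more draw of Gaussian noise. A minimal sketch with hypothetical names:

```python
import torch

def reverse_loop(model, x, sigmas, denoise=True):
    # Hypothetical reverse sampler; `model(x, t)` is assumed to predict the
    # noise present at level sigmas[t]. Names and update rule are illustrative.
    T = len(sigmas)
    for t in range(T):
        t_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
        x0_hat = x - sigmas[t] * model(x, t_batch)            # current denoised estimate
        if t < T - 1:
            x = x0_hat + sigmas[t + 1] * torch.randn_like(x)  # re-noise to the next (smaller) level
        elif denoise:
            x = x0_hat                                        # last step: keep the clean mean
        else:
            x = x0_hat + sigmas[t] * torch.randn_like(x)      # last step: leave residual noise in
    return x
```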
Hi, here is my understanding.
The authors use an NCSN-style diffusion model based on Langevin dynamics. In the original paper the starting parameter is larger than the ending one, so the betas are stored in decreasing order. During denoising the betas should go from large to small, so the list index runs from 0 to 999, and X_T is the real image (see the short sketch after the links below).
https://github.com/voletiv/mcvd-pytorch/blob/451da2eb635bad50da6a7c03b443a34c6eb08b3a/configs/kth64_big.yml#L81-L82
https://github.com/voletiv/mcvd-pytorch/blob/226a3fd1601a6fde1c59d01fb22f82cb37b3b8c4/models/__init__.py#L24-L26
https://github.com/voletiv/mcvd-pytorch/blob/226a3fd1601a6fde1c59d01fb22f82cb37b3b8c4/models/better/ncsnpp_more.py#L736-L739
https://github.com/voletiv/mcvd-pytorch/blob/451da2eb635bad50da6a7c03b443a34c6eb08b3a/models/__init__.py#L267
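In other words, because the schedule here is stored from the largest noise level down to the smallest, walking the list index forward (0 → 999) already goes from the noisiest state to the cleanest one, so no extra reversal is needed. A minimal sketch of that convention (the exact values are an assumption for illustration, matching sigma_begin > sigma_end):

```python
import numpy as np

sigma_begin, sigma_end, num_timesteps = 0.02, 0.0001, 1000  # assumed values, begin > end

# Schedule stored from the largest noise level to the smallest.
betas = np.linspace(sigma_begin, sigma_end, num_timesteps)
assert betas[0] > betas[-1]

# Iterating t = 0, 1, ..., 999 therefore moves from the noisiest level to the
# least noisy one; index 0 is the start of the reverse (denoising) process.
for t in range(num_timesteps):
    current_level = betas[t]  # strictly decreasing in t
```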
For comparison: in the original DDPM paper the starting parameter is smaller than the ending one, so the betas are stored in increasing order. The denoising timestep runs from list index 999 down to 0, and X_0 is the real image (see the sanity check after the code excerpts below).
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/scripts/run_celebahq.py#L132-L137
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L26-L27
elif beta_schedule == 'linear':
  betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L205-L217
i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
img_0 = noise_fn(shape=shape, dtype=tf.float32)
_, img_final = tf.while_loop(
  cond=lambda i_, _: tf.greater_equal(i_, 0),
  body=lambda i_, img_: [
    i_ - 1,
    self.p_sample(
      denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn, return_pred_xstart=False)
  ],
  loop_vars=[i_0, img_0],
  shape_invariants=[i_0.shape, img_0.shape],
  back_prop=False
)
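The two conventions describe the same schedule, just indexed in opposite directions: iterating an increasing DDPM beta list from 999 down to 0 visits the same values as iterating a decreasing list from 0 up to 999. A quick sanity check (values assumed for illustration):

```python
import numpy as np

T = 1000
betas_increasing = np.linspace(0.0001, 0.02, T)  # DDPM convention: beta_start < beta_end
betas_decreasing = np.linspace(0.02, 0.0001, T)  # convention above: start > end

# Walking the increasing list backwards equals walking the decreasing list forwards.
assert np.allclose(betas_increasing[::-1], betas_decreasing)
```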
Latent diffusion (LDM) is similar. In LDM's DDPM the starting parameter is smaller than the ending one, the betas are stored in increasing order, the denoising timestep runs from list index 999 down to 0, and X_0 is the real image.
https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/configs/latent-diffusion/cin256-v2.yaml#L5-L6
linear_start: 0.0015
linear_end: 0.0195
https://github.com/CompVis/latent-diffusion/blob/171cf29fb54afe048b03ec73da8abb9d102d0614/ldm/modules/diffusionmodules/util.py#L22-L25
if schedule == "linear":
    betas = (
        torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
    )
https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddpm.py#L258-L260
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
    img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                        clip_denoised=self.clip_denoised)
LDM's DDIM sampler follows the same reversed order:
https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddim.py#L133-L160
time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

for i, step in enumerate(iterator):
    index = total_steps - i - 1
    ts = torch.full((b,), step, device=device, dtype=torch.long)

    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                              quantize_denoised=quantize_denoised, temperature=temperature,
                              noise_dropout=noise_dropout, score_corrector=score_corrector,
                              corrector_kwargs=corrector_kwargs,
                              unconditional_guidance_scale=unconditional_guidance_scale,
                              unconditional_conditioning=unconditional_conditioning)
    img, pred_x0 = outs
    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)
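The DDIM sampler keeps the same direction, only over a subsequence of timesteps: it builds an increasing list of selected steps and flips it before iterating, so sampling still runs from the largest timestep down to 0 while `index` counts down through the shorter schedule. A rough sketch of that selection (uniform stride assumed; it mirrors the `np.flip(timesteps)` logic above rather than reproducing LDM's exact helper):

```python
import numpy as np

num_train_timesteps = 1000
num_ddim_steps = 50

# Uniformly strided subset of the training timesteps, in increasing order.
stride = num_train_timesteps // num_ddim_steps
ddim_timesteps = np.arange(0, num_train_timesteps, stride)  # [0, 20, ..., 980]

# The sampler iterates the flipped list, i.e. from the largest timestep down to 0.
time_range = np.flip(ddim_timesteps)
for i, step in enumerate(time_range):
    index = len(ddim_timesteps) - i - 1  # counts down through the DDIM schedule
    # the real loop would call p_sample_ddim(img, cond, ts=step, index=index, ...) here
```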