DALLE2-pytorch
DALLE2-pytorch copied to clipboard
ddim makes the generation worse
Hi, I met an issue that when I use ddim for the decoder sampling, the pics don't look good.
When I change the sample step to 1000, it comes to the following result.
Could I ask how to fix it?
The following is the ddim part of my code.
` def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False): device = self.betas.device
b = shape[0]
img = torch.randn(shape, device = device)
timesteps = 250
times = torch.linspace(0., 1000, steps = timesteps + 2)[:-1]
times = list(reversed(times.int().tolist()))
time_pairs = list(zip(times[:-1], times[1:]))
print(time_pairs)
alphas = self.alphas_cumprod_prev
if not is_latent_diffusion:
lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
alpha = alphas[time]
alpha_next = alphas[time_next]
# print("alpha_next",alpha_next)
# print("alpha_next1",alpha_next1)
time_cond = torch.full((b,), time, device = device, dtype = torch.long)
pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
if learned_variance:
pred, _ = pred.chunk(2, dim = 1)
if predict_x_start:
x_start = pred
pred_noise = self.predict_noise_from_start(img, t = time_cond, x0 = pred)
else:
x_start = self.predict_start_from_noise(img, t = time_cond, noise = pred)
pred_noise = pred
if clip_denoised:
s = 1.
# clip by threshold, depending on whether static or dynamic
x_start = x_start.clamp(-s, s) / s
c1 = 1 * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
noise = torch.randn_like(img) if time_next > 0 else 0.
img = x_start * alpha_next.sqrt() + \
c1 * noise + \
c2 * pred_noise
img = self.unnormalize_img(img)
return img`
can you tell me the dataset you used thanks
can you tell me the dataset you used thanks
pics from Internet
can you tell me the dataset you used thanks
pics from Internet
Can I see the code you trained