ddpm-pytorch
ddpm-pytorch copied to clipboard
已知加噪过程时的噪声,反向去噪时无法去噪
up你好,我将正向加噪过程中使用到的高斯噪声保存了下来,在去噪的时候用到了这些噪声,但是发现最终得到的图像全是噪声点,请问一下这是咋回事啊,下面是我的代码,我是你在b站上的粉丝。`import numpy as np import torch from PIL import Image import os def preprocess_input(x): x /= 255 x -= 0.5 x /= 0.5 return x
def cvtColor(image):
    """Return ``image`` as a three-channel RGB image.

    An input whose shape is three-dimensional with a last axis of length 3
    is handed back untouched; anything else is converted by calling its
    ``convert('RGB')`` method.
    """
    shape = np.shape(image)
    already_rgb = len(shape) == 3 and shape[2] == 3
    if not already_rgb:
        image = image.convert('RGB')
    return image
def postprocess_output(x):
    """Rescale ``x`` with ``(x * 0.5 + 0.5) * 255`` and return it.

    The arithmetic uses augmented assignment, so a mutable array or tensor
    argument is updated in place and the same object is returned. An input
    value of -1 maps to 0 and an input value of 1 maps to 255.
    """
    half, full_scale = 0.5, 255
    x *= half
    x += half
    x *= full_scale
    return x
def extract(a, t, x_shape):
    """Look up ``a`` at indices ``t`` and shape the result for broadcasting.

    Gathers along the last dimension of ``a`` and reshapes the values to
    ``(batch, 1, ..., 1)`` with ``len(x_shape)`` dimensions in total, where
    ``batch`` is the first dimension of ``t``.
    """
    batch, *_ = t.shape
    picked = a.gather(-1, t)
    broadcast_shape = (batch,) + (1,) * (len(x_shape) - 1)
    return picked.reshape(broadcast_shape)


def perturb_x(sqrt_alphas_cumprod, x, t, noise, sqrt_one_minus_alphas_cumprod):
    """Blend ``x`` and ``noise`` using the coefficients selected by ``t``.

    Returns ``sqrt_alphas_cumprod[t] * x + sqrt_one_minus_alphas_cumprod[t] * noise``,
    with both coefficients broadcast over the non-batch dimensions of ``x``.
    """
    signal_scale = extract(sqrt_alphas_cumprod, t, x.shape)
    noise_scale = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    return signal_scale * x + noise_scale * noise
def remove_noise(remove_noise_coeff, noise, reciprocal_sqrt_alphas, x, t, use_ema=False):
    """Subtract the scaled ``noise`` from ``x`` and rescale the result.

    Computes ``(x - remove_noise_coeff[t] * noise) * reciprocal_sqrt_alphas[t]``
    with the per-timestep coefficients broadcast over ``x``.

    ``use_ema`` is accepted for interface compatibility only; it has no
    effect on the value that is computed.
    """
    noise_scale = extract(remove_noise_coeff, t, x.shape)
    rescale = extract(reciprocal_sqrt_alphas, t, x.shape)
    return (x - noise_scale * noise) * rescale
num_timesteps = 100 save_path = "tmp.jpg" if not os.path.exists("original_pic"): os.makedirs("original_pic") if not os.path.exists("after_pic"): os.makedirs("after_pic") image = Image.open("0_clean.png") image = cvtColor(image).resize([128, 128], Image.BICUBIC) image = np.array(image, dtype=np.float32) image = np.transpose(preprocess_input(image), (2, 0, 1)) x = torch.from_numpy(np.array(image, np.float32)) x = x[None,:,:,:] betas = torch.linspace(start=0.0001, end=0.02, steps=1000) alphas = 1 - betas alphas_cumprod = torch.cumprod(alphas,dim=0) sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod) sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod) reciprocal_sqrt_alphas = torch.sqrt(1 / alphas) remove_noise_coeff = betas / torch.sqrt(1 - alphas_cumprod) sigma = torch.sqrt(betas)
保留加噪过程中的epsilon,用于下个阶段的还原
epsilon_list = [] for t in range(num_timesteps): t = torch.tensor([t]) epsilon = torch.randn_like(x) epsilon_list.append(epsilon) x_t = perturb_x(sqrt_alphas_cumprod, x, t, epsilon, sqrt_one_minus_alphas_cumprod) tmp1 = x_t.clone() test_images = postprocess_output(tmp1[0].cpu().data.numpy().transpose(1, 2, 0)) Image.fromarray(np.uint8(test_images)).save(os.path.join("original_pic", str(t) + ".png"))
去噪过程随机采样的xt
x = x_t #x = torch.randn((1, 3, 128, 128)) for t in range(num_timesteps - 1, -1, -1): t_batch = torch.tensor([t]).repeat(1) x = remove_noise(remove_noise_coeff, epsilon_list[t], reciprocal_sqrt_alphas, x, t_batch) if t > 0: x += extract(sigma, t_batch, x.shape) * torch.randn_like(x) tmp = x.clone() test_images = postprocess_output(tmp[0].cpu().data.numpy().transpose(1, 2, 0)) Image.fromarray(np.uint8(test_images)).save(os.path.join("after_pic", str(t) + ".png"))
test_images = postprocess_output(x[0].cpu().data.numpy().transpose(1, 2, 0)) Image.fromarray(np.uint8(test_images)).save(save_path)`