denoising-diffusion-pytorch

How to generate images after training. Hope it helps you guys.

Open • CooperLuo32 opened this issue 1 year ago • 1 comment

```python
import os
from PIL import Image
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(dim = 64, dim_mults = (1, 2, 4, 8), flash_attn = True)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250    # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
)

trainer = Trainer(
    diffusion,
    'path/to/your/images',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 700000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = True,                       # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.load(4)  # load model-4.pt from the trainer's results folder (./results by default)

sampled_images = diffusion.sample(batch_size = 8)

samples_root = "./samples"
os.makedirs(samples_root, exist_ok = True)
len_samples = len(os.listdir(samples_root))  # continue numbering after files already in the folder

for i in range(sampled_images.size(0)):
    current_image_tensor = sampled_images[i]
    # (C, H, W) floats in [0, 1] -> (H, W, C) uint8, the layout PIL expects
    current_image = Image.fromarray((current_image_tensor.cpu().permute(1, 2, 0).numpy() * 255).astype('uint8'))
    file_name = f"output__image_{i + len_samples}.png"
    current_image.save(os.path.join(samples_root, file_name))

print("all samples are saved in the samples folder")
```

That's the code I wrote for generating images once you have your best checkpoint. Hope it helps! If it does, leave a comment and let me know. Thank you.
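Note that the `Image.fromarray(...)` line only post-processes the model's *output*; it is not the model's input. As far as I can tell from the library source, with the default `auto_normalize = True` the tensor returned by `diffusion.sample` is already scaled to [0, 1] with shape (batch, channels, height, width). Below is a minimal NumPy sketch of that conversion for a single image. The helper name `chw_float_to_hwc_uint8` is hypothetical, and the clipping and rounding are additions that are not in the script above.

```python
import numpy as np

def chw_float_to_hwc_uint8(image: np.ndarray) -> np.ndarray:
    """Convert one sampled image from (C, H, W) floats in [0, 1] to (H, W, C) uint8.

    Mirrors `tensor.cpu().permute(1, 2, 0).numpy() * 255` from the script,
    plus clipping and rounding so values slightly outside [0, 1] cannot
    overflow when cast to uint8 (hypothetical helper, not library code).
    """
    if image.ndim != 3:
        raise ValueError(f"expected a (C, H, W) array, got shape {image.shape}")
    hwc = np.transpose(image, (1, 2, 0))        # channels-last, as PIL expects
    scaled = np.clip(hwc, 0.0, 1.0) * 255.0     # [0, 1] -> [0, 255]
    return np.rint(scaled).astype(np.uint8)     # round to nearest, then cast
```

In the loop above this would be used as `Image.fromarray(chw_float_to_hwc_uint8(sampled_images[i].cpu().numpy()))`.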

CooperLuo32, Sep 25 '23 02:09

@CooperLuo32 thanks for sharing. May I check what the input to the model is? Is it random noise, as defined by `Image.fromarray((current_image_tensor.cpu().permute(1, 2, 0).numpy() * 255).astype('uint8'))`?

Also, is the `Trainer` class needed to load the checkpoints? I don't want to train further, only to use the checkpoint for inference.
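If I read the library's `Trainer.save` correctly, each `model-{milestone}.pt` file is a plain dict written with `torch.save`, and the `GaussianDiffusion` weights sit under the `'model'` key (next to entries such as `'step'`, `'opt'` and `'ema'`). Assuming that layout, the `Trainer` (and its image folder) should not be needed for inference: rebuild `Unet` and `GaussianDiffusion` with the same arguments used for training and load the state dict directly. This is a sketch, not an official API; the helper name `load_diffusion_checkpoint` is hypothetical.

```python
import torch

def load_diffusion_checkpoint(diffusion: torch.nn.Module, path: str, device: str = "cpu") -> torch.nn.Module:
    """Load trained weights into an already constructed diffusion model, no Trainer needed.

    Assumes (hypothetically) that the checkpoint is a dict whose 'model'
    entry holds the GaussianDiffusion state dict, as Trainer.save appears to write it.
    """
    data = torch.load(path, map_location=device)   # dict with 'step', 'model', 'opt', 'ema', ...
    diffusion.load_state_dict(data["model"])       # strict by default: config must match training
    diffusion.to(device)
    diffusion.eval()                               # switch off training-only behaviour such as dropout
    return diffusion
```

Usage, with the same config as the script above: build `diffusion = GaussianDiffusion(Unet(dim = 64, dim_mults = (1, 2, 4, 8), flash_attn = True), image_size = 128, timesteps = 1000, sampling_timesteps = 250)`, call `load_diffusion_checkpoint(diffusion, 'results/model-4.pt', 'cuda')` (assuming the default `results` folder), then `diffusion.sample(batch_size = 8)`. This loads the raw (non-EMA) weights, which is also what the script above samples from; during training the `Trainer` samples from its EMA copy (`trainer.ema.ema_model`), which may be worth comparing.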

For `sampled_images = diffusion.sample(batch_size=8)`, may I know what `sample` refers to?
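Roughly, `sample` runs the reverse diffusion process: it starts from pure Gaussian noise of shape (batch_size, channels, image_size, image_size) and repeatedly asks the trained Unet to predict the noise, removing some of it at each step. Because `sampling_timesteps = 250` is smaller than `timesteps = 1000`, the library switches to the faster DDIM sampler and visits only 250 of the 1000 steps. The toy NumPy sketch below illustrates deterministic DDIM (eta = 0) with a stand-in `predict_noise` function in place of the Unet. The linear beta schedule and all names are assumptions for illustration, not the library's code (which, among other details, clips its clean-image estimate).

```python
import numpy as np

def linear_alpha_bars(timesteps: int) -> np.ndarray:
    # Assumed linear beta schedule (1e-4 to 0.02, as in the DDPM paper);
    # alpha_bar[t] is the fraction of the original signal kept at step t.
    betas = np.linspace(1e-4, 0.02, timesteps)
    return np.cumprod(1.0 - betas)

def ddim_sample(predict_noise, shape, timesteps=1000, sampling_timesteps=250, seed=0):
    """Toy deterministic DDIM (eta = 0): denoise pure noise over a strided subset of steps."""
    rng = np.random.default_rng(seed)
    alpha_bars = linear_alpha_bars(timesteps)
    # Evenly spaced steps from T-1 down to 0, e.g. 250 out of 1000.
    times = np.linspace(timesteps - 1, 0, sampling_timesteps).round().astype(int)
    x = rng.standard_normal(shape)   # x_T: pure Gaussian noise is the actual starting point
    x0 = x
    for i, t in enumerate(times):
        ab_t = alpha_bars[t]
        eps = predict_noise(x, t)    # in the real model, this is the trained Unet's job
        # Estimate the clean image implied by the predicted noise.
        x0 = (x - np.sqrt(1.0 - ab_t) * eps) / np.sqrt(ab_t)
        if i + 1 < len(times):
            ab_next = alpha_bars[times[i + 1]]
            # Jump to the next, less noisy step, reusing the same predicted noise.
            x = np.sqrt(ab_next) * x0 + np.sqrt(1.0 - ab_next) * eps
    return x0                        # clean estimate after the final step
```

Plugging in an "oracle" predictor that knows a fixed target image recovers that target exactly, which is a handy sanity check of the update rule; a trained Unet only approximates the noise, so each call with fresh noise yields a new image.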

lchunleo, Nov 15 '23 08:11