denoising-diffusion-pytorch
Sanity Check - Looking for a basic CIFAR10 hyperparameter set
I'm running the denoising_diffusion_pytorch.py script as-is on the CIFAR10 dataset, but the FID quickly plateaus at ~90, a far cry from the values reported in the DDPM/DDIM papers and even in other open issues (e.g. #326). Here are my hyperparameters:
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True,
    dropout = 0.1
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 1000,           # number of diffusion steps used in training
    sampling_timesteps = 250    # fewer than timesteps, so sampling takes fewer steps
)

trainer = Trainer(
    diffusion,
    './data/cifar10',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 100000,
    gradient_accumulate_every = 2,
    ema_decay = 0.995,
    num_fid_samples = 500,
    save_and_sample_every = 1000,
    amp = False,
    calculate_fid = True
)
No matter how I tune it, I can't seem to beat ~70. Am I going crazy? I feel like there's something obvious I'm missing, but I can't see what.
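
For anyone comparing against other setups, here is a rough sketch of what this config amounts to in terms of data seen. The batch size, accumulation, and step count are copied from the config above. The 50,000-image figure is my assumption that the standard CIFAR10 training split is being used, and I'm assuming each training step processes `train_batch_size * gradient_accumulate_every` images.

```python
# Values copied from the config above.
train_batch_size = 32
gradient_accumulate_every = 2
train_num_steps = 100_000
num_fid_samples = 500

# Assumption: the standard CIFAR10 training split of 50,000 images.
cifar10_train_images = 50_000

# Images contributing to each optimizer step.
effective_batch_size = train_batch_size * gradient_accumulate_every

# Total images processed over the whole run, and the equivalent number of epochs.
images_seen = effective_batch_size * train_num_steps
epochs = images_seen / cifar10_train_images

print(f"effective batch size: {effective_batch_size}")
print(f"images seen:          {images_seen:,}")
print(f"approx. epochs:       {epochs:.0f}")
print(f"FID sample count:     {num_fid_samples}")
```

The FID here is computed on only 500 samples, whereas published CIFAR10 numbers are typically computed on far more generated images, so the two may not be directly comparable.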