Sanity Check - Looking for a basic CIFAR10 hyperparameter set

Open samuelemarro opened this issue 7 months ago • 6 comments

I'm running the denoising_diffusion_pytorch.py script as-is on the CIFAR10 dataset, but the FID quickly plateaus at ~90, which is a far cry from the values reported in the DDIM/DDPM papers and even from those in other open issues (e.g. #326). Here are my hyperparameters:

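from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

# U-Net denoiser
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    flash_attn = True,
    dropout = 0.1
)

# Diffusion wrapper: 1000 training timesteps, 250 sampling timesteps
diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 1000,
    sampling_timesteps = 250
)

# Training loop with periodic sampling and FID evaluation
trainer = Trainer(
    diffusion,
    './data/cifar10',
    train_batch_size = 32,
    train_lr = 8e-5,
    train_num_steps = 100000,
    gradient_accumulate_every = 2,
    ema_decay = 0.995,
    num_fid_samples = 500,
    save_and_sample_every = 1000,
    amp = False,
    calculate_fid = True
)

trainer.train()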
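
As a rough sanity check on the training budget these settings imply, the arithmetic can be sketched as below. It assumes that each of the `train_num_steps` steps consumes `gradient_accumulate_every` batches, and that the image folder holds the standard 50,000-image CIFAR10 training split. `dataset_size` is a name introduced here for illustration, not a Trainer parameter.

```python
# Values copied from the configuration above.
train_batch_size = 32
gradient_accumulate_every = 2
train_num_steps = 100_000

# Assumption: the folder contains the standard 50,000-image CIFAR10 training split.
dataset_size = 50_000

# Assumption: one training step accumulates gradients over
# `gradient_accumulate_every` batches before the optimizer update.
effective_batch_size = train_batch_size * gradient_accumulate_every
images_seen = effective_batch_size * train_num_steps
approx_epochs = images_seen / dataset_size

print(effective_batch_size, images_seen, approx_epochs)
```

Under those assumptions this works out to an effective batch size of 64 and about 128 passes over the training set.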

No matter how I tune it, I can't seem to beat ~70. Am I going crazy? I feel like there's something obvious I'm missing, but I can't see what.

samuelemarro, Jun 30 '24 13:06