denoising-diffusion-pytorch

Running on CIFAR 10

Open • DushyantSahoo opened this issue 3 years ago • 14 comments

Hi,

I am trying to train on and sample from the CIFAR-10 dataset. Here is my code:

```python
import numpy as np
import tensorflow as tf
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 16,
    dim_mults = (1, 2, 4)
)

diffusion = GaussianDiffusion(
    model,
    image_size = 32,
    timesteps = 1000,   # number of steps
    loss_type = 'l1'    # L1 or L2
)

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# x_train: uint8, shape (50000, 32, 32, 3); scale to [0, 1] as in the README example
x_train = x_train.astype(np.float32) / 255.0
# channels-first (N, 3, 32, 32), keeping height and width in their original order
new_train = torch.from_numpy(x_train).permute(0, 3, 1, 2).contiguous()

trainer = Trainer(
    diffusion,
    new_train,                        # a tensor, accepted by the modified Trainer below
    train_batch_size = 128,
    train_lr = 1e-4,
    train_num_steps = 70000,          # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.9999,               # exponential moving average decay
    amp = True                        # turn on mixed precision
)
trainer.train()
```

I modified Trainer so that it takes the dataset directly. The original Trainer had this code:

```python
self.ds = Dataset(folder, self.image_size, augment_horizontal_flip = augment_horizontal_flip)
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
```

which I changed to:

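```python
# data: a float tensor of shape (N, 3, H, W) with values in [0, 1]
self.ds = data
# indexing the tensor yields one (3, H, W) image per sample; TensorDataset would yield 1-tuples
dl = DataLoader(self.ds, batch_size = train_batch_size, shuffle = True, pin_memory = True, num_workers = cpu_count())
self.dl = cycle(dl)
```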

With the setup above, the training loss goes to inf after 20k iterations. If I stop before that and sample, the images are just a bunch of random colors. Is there a script I can use to train on CIFAR-10 and generate samples?

Thank you

DushyantSahoo, Jul 08 '22
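Here is a note on data preparation for this setup. It is a numpy-only sketch that uses random stand-in arrays rather than real CIFAR-10.

- **Value range.** `tf.keras.datasets.cifar10.load_data()` returns uint8 images shaped (N, 32, 32, 3) with values in [0, 255]. The library's README example uses images normalized to [0, 1].
- **Axis order.** `np.swapaxes(x, 1, 3)` gives the expected (N, 3, 32, 32) shape but transposes every image, so rows become columns. `transpose(0, 3, 1, 2)` does not.

The helper name `to_nchw_unit_range` is hypothetical. The thread does not establish whether either issue contributes to the divergence reported here.

```python
import numpy as np


def to_nchw_unit_range(images_nhwc: np.ndarray) -> np.ndarray:
    """Convert uint8 images (N, H, W, C) in [0, 255] to float32 (N, C, H, W) in [0, 1]."""
    if images_nhwc.dtype != np.uint8 or images_nhwc.ndim != 4:
        raise ValueError("expected a uint8 array shaped (N, H, W, C)")
    scaled = images_nhwc.astype(np.float32) / 255.0
    # move channels first while keeping the height and width axes in order
    return np.ascontiguousarray(scaled.transpose(0, 3, 1, 2))


# Random stand-in for CIFAR-10 (not real data): 4 RGB images of 32x32 pixels
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)

good = to_nchw_unit_range(fake)
bad = np.swapaxes(fake, 1, 3)  # same (4, 3, 32, 32) shape, but every image is transposed

print(good.shape, good.dtype)  # (4, 3, 32, 32) float32
# With swapaxes, the pixel at row 5, column 7 ends up at row 7, column 5:
print(np.array_equal(bad[0, :, 7, 5], fake[0, 5, 7, :]))  # True
```

With real data, the output of `to_nchw_unit_range(x_train)` can be wrapped with `torch.from_numpy` to build the tensor for a modified Trainer.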

Maybe you can convert the CIFAR-10 dataset to PNG images, store them in a folder, and then train using the author's second method (passing that folder to Trainer).

LangdonYu, Aug 02 '22
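Here is a minimal sketch of the PNG-export suggestion above, assuming Pillow and numpy are installed. It writes every image directly into one folder, with no class subfolders, so that the folder can be passed to the stock Trainer as in the README.

The function name `export_pngs` and the zero-padded file names are hypothetical choices. PNG is lossless for 8-bit RGB, so the demo also checks that one image reads back with identical pixel values.

```python
import tempfile
from pathlib import Path

import numpy as np
from PIL import Image


def export_pngs(images_nhwc: np.ndarray, out_dir: str) -> int:
    """Save each uint8 (H, W, 3) image as out_dir/<index>.png, with no class subfolders."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, img in enumerate(images_nhwc):
        Image.fromarray(np.ascontiguousarray(img)).save(out / f"{i:05d}.png")
    return len(images_nhwc)


# Demo on random stand-in images; with real data you would pass x_train from
# tf.keras.datasets.cifar10.load_data() and a permanent output folder instead.
with tempfile.TemporaryDirectory() as tmp:
    demo = np.random.default_rng(0).integers(0, 256, size=(3, 32, 32, 3), dtype=np.uint8)
    written = export_pngs(demo, tmp)
    names = sorted(p.name for p in Path(tmp).iterdir())
    with Image.open(Path(tmp) / "00001.png") as im:
        roundtrip = np.array(im)  # PNG is lossless, so this should equal demo[1]

print(written, names)  # 3 ['00000.png', '00001.png', '00002.png']
print(np.array_equal(roundtrip, demo[1]))  # True
```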

And how do you handle the fact that there is a separate folder for each class? I unpacked the dataset into PNGs, but when I try to train the model I get an error: ValueError: num_samples should be a positive integer, but got num_samples=0. This indicates that the model can't read the dataset. I hoped that converting CIFAR-10 to PNG would help with this problem, but it didn't.

HalcyonForest, Aug 02 '22
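About the `num_samples=0` error: PyTorch's `RandomSampler`, which `DataLoader` uses when `shuffle=True`, raises it when the dataset has length zero. In other words, the folder dataset found no image files at the given path.

Whether the repo's `Dataset` searches subfolders depends on the installed version, so check how it collects its file paths. If per-class subfolders are the problem, one workaround is to copy everything into a single flat folder.

This sketch uses only the standard library. The name `flatten_class_folders`, the extension set and the `<class>_<name>` naming are hypothetical choices. The demo uses placeholder files rather than real images.

```python
import shutil
import tempfile
from pathlib import Path

IMAGE_EXTS = {".png", ".jpg", ".jpeg"}


def flatten_class_folders(src_root: str, dst_dir: str) -> int:
    """Copy every image under src_root (e.g. src_root/<class>/<file>.png) into dst_dir.

    The parent folder name is prefixed to each file name so that files with the same
    name in different class folders do not overwrite each other.
    """
    src, dst = Path(src_root), Path(dst_dir)
    dst.mkdir(parents=True, exist_ok=True)
    count = 0
    # sorted() collects the full listing before any copying starts
    for path in sorted(src.rglob("*")):
        if path.is_file() and path.suffix.lower() in IMAGE_EXTS:
            shutil.copy2(path, dst / f"{path.parent.name}_{path.name}")
            count += 1
    return count


# Demo with placeholder files standing in for real PNGs
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "cifar10_by_class"
    for cls in ("airplane", "cat"):
        (root / cls).mkdir(parents=True)
        (root / cls / "0001.png").write_bytes(b"placeholder")
    flat = Path(tmp) / "cifar10_flat"
    copied = flatten_class_folders(str(root), str(flat))
    flat_names = sorted(p.name for p in flat.iterdir())

print(copied, flat_names)  # 2 ['airplane_0001.png', 'cat_0001.png']
```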

Hi, did you solve this problem?

Allencheng97, Aug 11 '22

@Allencheng97

Yes, I forked this repo and replaced the Dataset(...) call with torchvision.datasets.CIFAR10(...) (around lines 719-721).

HalcyonForest, Aug 11 '22
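The thread does not show the actual CIFAR10 change, so here is a hedged sketch of the swap described above.

`torchvision.datasets.CIFAR10` returns `(image, label)` pairs. The repo's folder `Dataset` appears to return only the transformed image, so a direct replacement probably needs the label dropped.

`ImagesOnly` is a hypothetical wrapper. A `DataLoader` only needs `__len__` and `__getitem__` from a map-style dataset, so this sketch runs without torch. Subclassing `torch.utils.data.Dataset` is the usual convention.

```python
from typing import Any, Sequence, Tuple


class ImagesOnly:
    """Wrap a dataset of (image, label) pairs so that indexing returns only the image."""

    def __init__(self, pairs: Sequence[Tuple[Any, Any]]):
        self.pairs = pairs

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, index: int) -> Any:
        image, _label = self.pairs[index]  # the class label is not needed for unconditional training
        return image


# Intended use inside the Trainer (assumes torchvision is installed; not run here):
#   import torchvision.transforms as T
#   from torchvision.datasets import CIFAR10
#   cifar = CIFAR10("data", train=True, download=True,
#                   transform=T.Compose([T.RandomHorizontalFlip(), T.ToTensor()]))
#   self.ds = ImagesOnly(cifar)  # ToTensor gives (3, 32, 32) float tensors in [0, 1]

demo = ImagesOnly([("image-0", 6), ("image-1", 9)])  # stand-in for CIFAR10 pairs
print(len(demo), demo[1])  # 2 image-1
```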

Thanks!

Allencheng97, Aug 11 '22

I have found the cause: you should set amp = False when training on CIFAR-10. When I did this, the model converged and generated normal pictures instead of a bunch of random colours.

greens007, Aug 19 '22

Would it be possible to share your code? I'm having some issues getting my version to actually converge, despite also using cifar10. Thanks!

DevJake, Aug 24 '22

Hi, is your problem resolved? I am facing a similar issue

SilvesterYu, Sep 21 '22

My use case for this library was slightly different from its original purpose, so my version is considerably modified. It does resolve the issue with training on CIFAR-10, although how much value its modifications have for you may vary.

You can check out my repository at DevJake/EEG-diffusion-pytorch. Let me know if it's of use!

DevJake, Sep 21 '22

@HalcyonForest Yes, your method of swapping in CIFAR10 works. I heard from my lab that converting the images to PNG for training costs a bit of accuracy.

LangdonYu, Oct 12 '22

I wonder why that is. I would have figured the subtle compression applied by the JPG format would cause some data loss and thus a loss in performance. Equally, it might act as a form of very low-level image augmentation by adding artefacts. Do you have anything more on your findings? I'd be interested to know how it was determined.

DevJake, Oct 12 '22

https://github.com/lucidrains/denoising-diffusion-pytorch/issues/57#issuecomment-1220196043

I faced the same problem with FFHQ, and it was fixed by setting amp=False. But why? @greens007

kaka45inablink, Mar 29 '23
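The thread does not establish why `amp=False` helps. One plausible contributor is float16's narrow range; treat this as an assumption rather than a diagnosis.

Mixed precision runs many operations in float16, which cannot represent magnitudes above 65504. Any intermediate value that grows past that limit becomes inf, which would be consistent with the loss going to inf in the original report.

This numpy sketch only illustrates the range limit. It does not reproduce the training run.

```python
import numpy as np

fp16_max = float(np.finfo(np.float16).max)
print(fp16_max)  # 65504.0

x = np.array([100.0, 300.0], dtype=np.float16)
with np.errstate(over="ignore"):  # silence the overflow warning for the demo
    squared_fp16 = x * x  # 300 * 300 = 90000 exceeds 65504, so it becomes inf
squared_fp32 = x.astype(np.float32) * x.astype(np.float32)  # float32 represents 90000 exactly

print(np.isinf(squared_fp16).tolist())  # [False, True]
print(squared_fp32.tolist())  # [10000.0, 90000.0]
```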

@LangdonYu How exactly should I modify it? I could not find torch.Dataset() at line 721 in the denoising_diffusion_pytorch.py file. Thanks!

baizhenzheng, Apr 24 '23

Did you solve this problem?

neilwen987, Aug 08 '23