denoising-diffusion-pytorch

Training on CelebA-HQ

Status: Open. moonnnpie opened this issue 11 months ago • 5 comments

Thanks for your work.

I am training on the CelebA-HQ dataset, and after 110k steps the generated images seem to have a color problem. Is there something I need to do differently with the dataset?

[image: 64a5ac5ea03fc669048ef68a3db224f]
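One thing worth ruling out when colors look off is the input value range. As far as I know, `GaussianDiffusion` (with its default `auto_normalize = True`) expects training images in [0, 1], maps them to [-1, 1] internally as `x * 2 - 1`, and maps generated samples back with `(x + 1) * 0.5`. The sketch below is a library-free illustration of that round trip plus a hypothetical `check_pixel_range` helper for a list of pixel values; the helper is not part of the library, and plain floats stand in for tensors.

```python
def normalize_to_neg_one_to_one(x: float) -> float:
    # [0, 1] -> [-1, 1], the mapping applied to training images
    return x * 2 - 1


def unnormalize_to_zero_to_one(x: float) -> float:
    # [-1, 1] -> [0, 1], the mapping applied to generated samples
    return (x + 1) * 0.5


def check_pixel_range(pixels, lo=0.0, hi=1.0):
    """Hypothetical helper: raise if any pixel value falls outside [lo, hi]."""
    bad = [p for p in pixels if not (lo <= p <= hi)]
    if bad:
        raise ValueError(
            f"{len(bad)} value(s) outside [{lo}, {hi}], e.g. {bad[0]}; "
            "were the images already normalized, or left in [0, 255]?"
        )
    return min(pixels), max(pixels)


if __name__ == "__main__":
    # A correctly scaled batch, e.g. pil_to_tensor(img) / 255.
    ok = [0.0, 0.25, 0.5, 1.0]
    print(check_pixel_range(ok))  # (0.0, 1.0)
    print([unnormalize_to_zero_to_one(normalize_to_neg_one_to_one(p)) for p in ok])

    # Raw uint8 values that were never divided by 255 are caught here
    try:
        check_pixel_range([0.0, 128.0, 255.0])
    except ValueError as err:
        print("range check failed:", err)
```

In practice you would run the same kind of check on a few real batches from your dataset before training.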

Here are my settings:

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
).cuda()

diffusion = GaussianDiffusion(
    model,
    image_size = 256,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 250,   # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper])
    loss_type = 'l1'            # L1 or L2
).cuda()

trainer = Trainer(
    diffusion,
    '/mnt/shared/deepfake/CelebA-HQ/train',
    train_batch_size = 32,
    train_lr = 2e-5,
    train_num_steps = 7000000,        # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = False,                      # turn on mixed precision
    calculate_fid = True              # whether to calculate fid during training
)

trainer.train()

(moonnnpie, Mar 18 '24 09:03)
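A note on `ema_decay = 0.995` in the settings above: as far as I know, the Trainer draws its samples from the EMA copy of the model, so what you see is an average over recent weights rather than the latest weights. An EMA with decay β averages over roughly 1 / (1 - β) updates, about 200 here. The sketch below illustrates that rule of thumb on a single scalar; the function names are illustrative, and it ignores any warm-up the EMA wrapper may apply (if the EMA is only updated every few optimizer steps, the window measured in steps is proportionally longer).

```python
def ema_update(avg: float, value: float, decay: float) -> float:
    # Keep `decay` of the running average and mix in (1 - decay) of the new value
    return decay * avg + (1.0 - decay) * value


def ema_window(decay: float) -> float:
    # Rule of thumb: an EMA averages over roughly 1 / (1 - decay) updates
    return 1.0 / (1.0 - decay)


if __name__ == "__main__":
    decay = 0.995
    print(round(ema_window(decay)))  # 200

    # Track a weight that jumps from 0 to 1: after 200 updates the average has
    # closed about 63% of the gap (1 - 0.995**200), i.e. it lags the raw weights.
    avg = 0.0
    for _ in range(200):
        avg = ema_update(avg, 1.0, decay)
    print(round(avg, 3), round(1 - decay ** 200, 3))
```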

You can try setting 'amp = True'.

(Zhangzeyu7, Mar 19 '24 08:03)

Quoting Zhangzeyu7: "you can try 'amp = True'"

Thanks, but I tried it and the images turned out completely green.

(moonnnpie, Mar 19 '24 10:03)
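For context on the `amp` suggestion: with `amp = True` the Trainer trains in mixed precision (fp16 by default, as far as I know), where many values are stored as IEEE 754 half-precision floats. Half precision has a largest finite value of 65504 and only about three significant decimal digits, so values that overflow become inf/NaN and can corrupt training. Whether that explains the green images here is not established in this thread. The sketch below uses Python's `struct` 'e' (half-precision) format to show both limits; `to_fp16` and `fits_in_fp16` are hypothetical helpers.

```python
import math
import struct


def to_fp16(x: float) -> float:
    # Round-trip a Python float through IEEE 754 half precision ('e' format).
    # struct raises OverflowError if x is finite but too large for fp16.
    return struct.unpack("<e", struct.pack("<e", x))[0]


def fits_in_fp16(x: float) -> bool:
    """Hypothetical helper: True if x survives the fp16 round trip as a finite number."""
    try:
        return math.isfinite(to_fp16(x))
    except OverflowError:
        return False


if __name__ == "__main__":
    print(fits_in_fp16(65504.0))  # True: largest finite fp16 value
    print(fits_in_fp16(1e5))      # False: too large for fp16
    print(to_fp16(1.0001))        # 1.0: the small increment is lost to rounding
```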

I am also trying to get reasonable results on the FFHQ dataset, and I want to try CelebA-HQ as well.

from glob import glob

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from torchvision.transforms.functional import pil_to_tensor
from PIL import Image

class FFHQDataset(VisionDataset):
    def __init__(self, root: str):
        super().__init__(root)

        self.fpaths = sorted(glob(root + '/**/*.png', recursive=True))
        assert len(self.fpaths) > 0, "File list is empty. Check the root."

    def __len__(self):
        return len(self.fpaths)

    def __getitem__(self, index: int):
        fpath = self.fpaths[index]
        img = Image.open(fpath).convert('RGB')
        # normalize to [0, 1] range
        img = pil_to_tensor(img) / 255.
        return img


model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8, 16, 32),
    flash_attn = True
)

diffusion = GaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,           # number of steps
    sampling_timesteps = 500    # number of sampling timesteps (DDIM)
)

dataset = FFHQDataset(root="/mnt/SSD2/nils/ocean_bench_exps/diffusion/data/ffhq/thumbnails128x128")
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True, pin_memory = True, num_workers = 12)

trainer = Trainer(
    diffusion,
    'path/to/your/images',            # placeholder; with the modified Trainer, images come from `dl` below
    train_lr = 8e-5,
    train_num_steps = 50000,         # total training steps
    gradient_accumulate_every = 2,    # gradient accumulation steps
    ema_decay = 0.995,                # exponential moving average decay
    amp = False,
    num_samples = 16,
    save_and_sample_every = 10000,
    dl = dataloader,                  # custom argument added to a locally modified Trainer (not in the stock API)
)

trainer.train()

I added a dataloader argument (`dl`) to the Trainer because I wanted control over different dataloaders and their configurations. Effectively, I replaced the Trainer's dataset and dataloader construction block so that it just uses the `dl` argument. Below are some samples; the loss is around 0.02-0.03.
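The modification described above (passing in a ready-made dataloader instead of letting the Trainer build one) can be sketched without the library. As far as I know, the stock Trainer wraps its DataLoader in a small `cycle` generator so it can call `next()` indefinitely; the hypothetical `make_batch_iterator` below prefers a caller-supplied iterable and only falls back to building batches itself when none is given. Plain lists stand in for a DataLoader, and none of these names are the library's API.

```python
from typing import Iterable, Iterator, List, Optional, Sequence


def cycle(dl: Iterable) -> Iterator:
    # Loop over the dataloader forever, restarting it at the end of each epoch
    while True:
        for batch in dl:
            yield batch


def default_batches(dataset: Sequence, batch_size: int) -> List[list]:
    # Stand-in for the Trainer's own DataLoader construction (no shuffling here)
    return [list(dataset[i:i + batch_size]) for i in range(0, len(dataset), batch_size)]


def make_batch_iterator(
    dataset: Optional[Sequence] = None,
    batch_size: int = 16,
    dl: Optional[Iterable] = None,
) -> Iterator:
    """Hypothetical: use a caller-supplied `dl` if given, else build batches from `dataset`."""
    if dl is None:
        if dataset is None:
            raise ValueError("pass either `dl` or `dataset`")
        dl = default_batches(dataset, batch_size)
    return cycle(dl)


if __name__ == "__main__":
    custom_dl = [[0, 1], [2, 3]]          # pretend DataLoader with batch size 2
    it = make_batch_iterator(dl=custom_dl)
    print([next(it) for _ in range(3)])   # [[0, 1], [2, 3], [0, 1]]

    it = make_batch_iterator(dataset=list(range(5)), batch_size=2)
    print([next(it) for _ in range(4)])   # [[0, 1], [2, 3], [4], [0, 1]]
```

Injecting the dataloader keeps batch size, worker count and shuffling under the caller's control. One side effect of modifying the Trainer this way is that its own `train_batch_size` no longer determines the batch size; the DataLoader's `batch_size` does.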

It was mentioned here that amp=False helps, but I have tried both settings and see no significant difference.

[image: Screenshot from 2024-03-22 12-56-48]

Overall, I would expect better results, so I am wondering whether anyone has experience or suggestions.

Edit: Training for longer (300,000 training steps) seems to improve the results a bit:

[image: sample-30]

(nilsleh, Mar 22 '24 12:03)
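To put the step counts in perspective, the number of passes over the data can be estimated as steps × batch size × `gradient_accumulate_every` / dataset size, since each optimizer step draws `gradient_accumulate_every` batches. The sketch assumes the custom DataLoader's batch size of 32, `gradient_accumulate_every = 2` from the FFHQ code above, and the 70,000 images of FFHQ; substitute your own numbers (for example, the size of your CelebA-HQ training split). The helper names are illustrative.

```python
def images_seen(steps: int, batch_size: int, grad_accum: int) -> int:
    # Each optimizer step draws `grad_accum` batches from the dataloader
    return steps * batch_size * grad_accum


def epochs(steps: int, batch_size: int, grad_accum: int, dataset_size: int) -> float:
    # Approximate number of full passes over the dataset
    return images_seen(steps, batch_size, grad_accum) / dataset_size


if __name__ == "__main__":
    for steps in (50_000, 300_000):
        n = images_seen(steps, 32, 2)
        print(f"{steps:>7} steps -> {n:,} images, ~{epochs(steps, 32, 2, 70_000):.0f} epochs")
```

Under these assumptions, 50,000 steps is roughly 46 epochs and 300,000 steps roughly 274 epochs.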

These are results on the CelebA-HQ dataset:

[image: sample-29]

(nilsleh, Apr 08 '24 09:04)

Quoting the exchange above:

"you can try 'amp = True'"

"thanks but i tried and found that the images turns out total green"

Have you solved this problem? I ran into the same problem recently.

(szh404, Jun 13 '24 12:06)