conditional-ddpm
A simple PyTorch implementation of conditional denoising diffusion probabilistic models (DDPM) on MNIST, Fashion-MNIST, and Sprite datasets
Conditional DDPM
Introduction
We implement, in PyTorch, a simple conditional form of the diffusion model described in Denoising Diffusion Probabilistic Models. In preparing this repository, we were inspired by the course How Diffusion Models Work and the repository minDiffusion. For training, we use the MNIST, Fashion-MNIST, and Sprite (see FrootsnVeggies and kyrise) datasets.
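For context, the model is trained with the standard epsilon-prediction objective from the DDPM paper, with the class label passed to the denoising network as an extra input. The sketch below is a minimal illustration under assumed interfaces (a hypothetical epsilon-prediction network `model(x_t, t, class_labels)` and a linear noise schedule); it is not the repository's actual code.

```python
import torch
import torch.nn.functional as F

T = 500                                    # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)      # linear noise schedule, as in the DDPM paper
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)  # cumulative product \bar{alpha}_t

def training_step(model, x0, class_labels):
    """One conditional DDPM training step on a clean batch x0 with integer class labels."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)          # random timestep per sample
    noise = torch.randn_like(x0)                              # epsilon ~ N(0, I)
    a_bar = alpha_bars.to(x0.device)[t].view(b, 1, 1, 1)
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise      # forward (noising) process
    pred_noise = model(x_t, t, class_labels)                  # predict epsilon, conditioned on the class
    return F.mse_loss(pred_noise, noise)                      # simple L2 objective from the DDPM paper
```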
Setting Up the Environment
- Install Conda, if not already installed.
- Clone the repository:
git clone https://github.com/byrkbrk/diffusion-model.git
- In the directory diffusion-model, for macOS, run:
conda env create -f diffusion-env_macos.yaml
For Linux or Windows, run:
conda env create -f diffusion-env_linux_or_windows.yaml
- Activate the environment:
conda activate diffusion-env
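As an optional sanity check (our suggestion, not part of the repository), the snippet below verifies that PyTorch imports correctly inside the activated environment and reports which accelerator is available:

```python
import torch

print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
    print("CUDA device available:", torch.cuda.get_device_name(0))
elif torch.backends.mps.is_available():
    print("Apple MPS backend available")
else:
    print("Running on CPU")
```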
Training and Sampling
MNIST
To train the model on the MNIST dataset from scratch, run:
python3 train.py --dataset-name mnist
In order to sample from our (pretrained) checkpoint, run:
python3 sample.py pretrained_mnist_checkpoint_49.pth --n-samples 400 --n-images-per-row 20
Results (jpeg and gif files) will be saved into the generated-images directory, and are shown below, where each pair of rows represents a class label (20 rows and 10 classes in total).
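For reference, the sampling command reverses the diffusion process conditioned on the requested class labels. The sketch below is a rough, self-contained simplification of that idea; the repository's sample.py, its checkpoint format, and its exact interfaces may differ.

```python
import math
import torch

T = 500                                    # number of diffusion steps (assumed)
betas = torch.linspace(1e-4, 0.02, T)      # linear noise schedule, as in the DDPM paper
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

@torch.no_grad()
def sample(model, class_labels, shape=(1, 28, 28), device="cpu"):
    """Generate one image per class label by running the reverse diffusion process."""
    n = class_labels.shape[0]
    x = torch.randn(n, *shape, device=device)                   # start from pure Gaussian noise x_T
    for i in reversed(range(T)):
        t = torch.full((n,), i, device=device, dtype=torch.long)
        eps = model(x, t, class_labels)                          # predicted noise, conditioned on class
        a, a_bar, b = alphas[i].item(), alpha_bars[i].item(), betas[i].item()
        mean = (x - (1 - a) / math.sqrt(1 - a_bar) * eps) / math.sqrt(a)
        noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = mean + math.sqrt(b) * noise                          # no noise is added at the final step
    return x
```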
Fashion-MNIST
To train the model from scratch on the Fashion-MNIST dataset, run:
python3 train.py --dataset-name fashion_mnist
In order to sample from our (pretrained) checkpoint, run:
python3 sample.py pretrained_fashion_mnist_checkpoint_49.pth --n-samples 400 --n-images-per-row 20
Results (jpeg and gif files) will be saved into the generated-images directory, and are shown below, where each pair of rows represents a class label (20 rows and 10 classes in total).
Sprite
To train the model from scratch on the Sprite dataset, run:
python3 train.py --dataset-name sprite
In order to sample from our (pretrained) checkpoint, run:
python3 sample.py pretrained_sprite_checkpoint_49.pth --n-samples 225 --n-images-per-row 15
Results (jpeg and gif files) will be saved into the generated-images directory, and are shown below, where each group of three rows represents a class label (15 rows and 5 classes in total).
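For illustration only, a grid like the saved jpeg output can be assembled with torchvision, mirroring the --n-samples and --n-images-per-row options; the tensor shape and file name below are placeholders, and sample.py handles this (along with the gif output) internally.

```python
import os
import torch
from torchvision.utils import make_grid, save_image

os.makedirs("generated-images", exist_ok=True)

samples = torch.rand(225, 3, 16, 16)               # placeholder batch standing in for 225 generated sprites
grid = make_grid(samples, nrow=15)                 # 15 images per row, as in the Sprite example above
save_image(grid, "generated-images/samples.jpeg")  # write the grid as a single jpeg file
```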