
Unofficial PyTorch Implementation of Denoising Diffusion Probabilistic Models (DDPM) [paper] [official repo]


Features

  • [x] Original DDPM[^1] training & sampling (objective sketched below)
  • [x] DDIM[^2] sampler
  • [x] Standard evaluation metrics
    • [x] Fréchet Inception Distance[^3] (FID)
    • [x] Precision & Recall[^4]
  • [x] Distributed Data Parallel[^5] (DDP) multi-GPU training
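
As a quick reference for the first feature, here is a minimal PyTorch sketch of the DDPM noise-prediction objective. It follows the standard formulation from the paper rather than this repository's training code; the model signature model(x_t, t) and the linear schedule values are illustrative assumptions.

    import torch
    import torch.nn.functional as F

    T = 1000                                        # --timesteps
    betas = torch.linspace(1e-4, 0.02, T)           # a linear --beta-schedule
    alphas_bar = torch.cumprod(1.0 - betas, dim=0)

    def ddpm_loss(model, x0):
        # Sample a timestep and noise, diffuse x0, and regress the injected noise
        # (the eps / mse choice of --model-mean-type and --loss-type).
        t = torch.randint(0, T, (x0.shape[0],), device=x0.device)
        eps = torch.randn_like(x0)
        a_bar = alphas_bar.to(x0.device)[t].view(-1, 1, 1, 1)
        x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * eps    # q(x_t | x_0)
        return F.mse_loss(model(x_t, t), eps)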

Requirements

  • torch>=1.12.0
  • torchvision>=0.13.0
  • scipy>=1.7.3

Code usage

Toy data: training (train_toy.py)

usage: train_toy.py [-h] [--dataset {gaussian8,gaussian25,swissroll}]      
                    [--size SIZE] [--root ROOT] [--epochs EPOCHS] [--lr LR]
                    [--beta1 BETA1] [--beta2 BETA2] [--lr-warmup LR_WARMUP]
                    [--batch-size BATCH_SIZE] [--timesteps TIMESTEPS]      
                    [--beta-schedule {quad,linear,warmup10,warmup50,jsd}]  
                    [--beta-start BETA_START] [--beta-end BETA_END]        
                    [--model-mean-type {mean,x_0,eps}]                     
                    [--model-var-type {learned,fixed-small,fixed-large}]   
                    [--loss-type {kl,mse}] [--image-dir IMAGE_DIR]         
                    [--chkpt-dir CHKPT_DIR] [--chkpt-intv CHKPT_INTV]      
                    [--eval-intv EVAL_INTV] [--seed SEED] [--resume]       
                    [--device DEVICE] [--mid-features MID_FEATURES]        
                    [--num-temporal-layers NUM_TEMPORAL_LAYERS]            
optional arguments:                                                        
  -h, --help            show this help message and exit                    
  --dataset {gaussian8,gaussian25,swissroll}                               
  --size SIZE                                                              
  --root ROOT           root directory of datasets                         
  --epochs EPOCHS       total number of training epochs                    
  --lr LR               learning rate                                      
  --beta1 BETA1         beta_1 in Adam                                     
  --beta2 BETA2         beta_2 in Adam                                     
  --lr-warmup LR_WARMUP                                                    
                        number of warming-up epochs                        
  --batch-size BATCH_SIZE                                                  
  --timesteps TIMESTEPS                                                    
                        number of diffusion steps                          
  --beta-schedule {quad,linear,warmup10,warmup50,jsd}                      
  --beta-start BETA_START                                                  
  --beta-end BETA_END                                                      
  --model-mean-type {mean,x_0,eps}
  --model-var-type {learned,fixed-small,fixed-large}
  --loss-type {kl,mse}
  --image-dir IMAGE_DIR
  --chkpt-dir CHKPT_DIR
  --chkpt-intv CHKPT_INTV
                        frequency of saving a checkpoint
  --eval-intv EVAL_INTV
  --seed SEED           random seed
  --resume              to resume training from a checkpoint
  --device DEVICE
  --mid-features MID_FEATURES
  --num-temporal-layers NUM_TEMPORAL_LAYERS
                
Real-world data: training (train.py)

usage: train.py [-h] [--dataset {mnist,cifar10,celeba,celebahq}] [--root ROOT]
                [--epochs EPOCHS] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]   
                [--batch-size BATCH_SIZE] [--num-accum NUM_ACCUM]
                [--block-size BLOCK_SIZE] [--timesteps TIMESTEPS]
                [--beta-schedule {quad,linear,warmup10,warmup50,jsd}]
                [--beta-start BETA_START] [--beta-end BETA_END]
                [--model-mean-type {mean,x_0,eps}]
                [--model-var-type {learned,fixed-small,fixed-large}]
                [--loss-type {kl,mse}] [--num-workers NUM_WORKERS]
                [--train-device TRAIN_DEVICE] [--eval-device EVAL_DEVICE]
                [--image-dir IMAGE_DIR] [--image-intv IMAGE_INTV]
                [--num-save-images NUM_SAVE_IMAGES] [--config-dir CONFIG_DIR]
                [--chkpt-dir CHKPT_DIR] [--chkpt-name CHKPT_NAME]
                [--chkpt-intv CHKPT_INTV] [--seed SEED] [--resume]
                [--chkpt-path CHKPT_PATH] [--eval] [--use-ema]
                [--ema-decay EMA_DECAY] [--distributed] [--rigid-launch]
                [--num-gpus NUM_GPUS] [--dry-run]
optional arguments:
  -h, --help            show this help message and exit
  --dataset {mnist,cifar10,celeba,celebahq}
  --root ROOT           root directory of datasets
  --epochs EPOCHS       total number of training epochs
  --lr LR               learning rate
  --beta1 BETA1         beta_1 in Adam
  --beta2 BETA2         beta_2 in Adam
  --batch-size BATCH_SIZE
  --num-accum NUM_ACCUM
                        number of mini-batches before an update
  --block-size BLOCK_SIZE
                        block size used for pixel shuffle
  --timesteps TIMESTEPS
                        number of diffusion steps
  --beta-schedule {quad,linear,warmup10,warmup50,jsd}
  --beta-start BETA_START
  --beta-end BETA_END
  --model-mean-type {mean,x_0,eps}
  --model-var-type {learned,fixed-small,fixed-large}
  --loss-type {kl,mse}
  --chkpt-path CHKPT_PATH
                        checkpoint path used to resume training
  --eval                whether to evaluate fid during training
  --use-ema             whether to use exponential moving average
  --ema-decay EMA_DECAY
                        decay factor of ema
  --distributed         whether to use distributed training
  --rigid-launch        whether to use torch multiprocessing spawn
  --num-gpus NUM_GPUS   number of gpus for distributed training
  --dry-run             test-run till the first model update completes
            	
Real-world data: generation (generate.py)

usage: generate.py [-h] [--dataset {mnist,cifar10,celeba,celebahq}]
                   [--batch-size BATCH_SIZE] [--total-size TOTAL_SIZE]
                   [--config-dir CONFIG_DIR] [--chkpt-dir CHKPT_DIR]
                   [--chkpt-path CHKPT_PATH] [--save-dir SAVE_DIR]
                   [--device DEVICE] [--use-ema] [--use-ddim] [--eta ETA]
                   [--skip-schedule SKIP_SCHEDULE] [--subseq-size SUBSEQ_SIZE]
                   [--suffix SUFFIX] [--max-workers MAX_WORKERS]
                   [--num-gpus NUM_GPUS]
optional arguments:
  -h, --help            show this help message and exit
  --dataset {mnist,cifar10,celeba,celebahq}
  --batch-size BATCH_SIZE
  --total-size TOTAL_SIZE
  --config-dir CONFIG_DIR
  --chkpt-dir CHKPT_DIR
  --chkpt-path CHKPT_PATH
  --save-dir SAVE_DIR
  --device DEVICE
  --use-ema
  --use-ddim
  --eta ETA
  --skip-schedule SKIP_SCHEDULE
  --subseq-size SUBSEQ_SIZE
  --suffix SUFFIX
  --max-workers MAX_WORKERS
  --num-gpus NUM_GPUS
			
Real-world data: evaluation (eval.py)

usage: eval.py [-h] [--root ROOT] [--dataset {mnist,cifar10,celeba,celebahq}]
               [--model-device MODEL_DEVICE] [--eval-device EVAL_DEVICE]
               [--eval-batch-size EVAL_BATCH_SIZE]
               [--eval-total-size EVAL_TOTAL_SIZE] [--num-workers NUM_WORKERS]
               [--nhood-size NHOOD_SIZE] [--row-batch-size ROW_BATCH_SIZE]
               [--col-batch-size COL_BATCH_SIZE] [--device DEVICE]
               [--eval-dir EVAL_DIR] [--precomputed-dir PRECOMPUTED_DIR]
               [--metrics METRICS [METRICS ...]] [--seed SEED]
               [--folder-name FOLDER_NAME]
optional arguments:
  -h, --help            show this help message and exit
  --root ROOT
  --dataset {mnist,cifar10,celeba,celebahq}
  --model-device MODEL_DEVICE
  --eval-device EVAL_DEVICE
  --eval-batch-size EVAL_BATCH_SIZE
  --eval-total-size EVAL_TOTAL_SIZE
  --num-workers NUM_WORKERS
  --nhood-size NHOOD_SIZE
  --row-batch-size ROW_BATCH_SIZE
  --col-batch-size COL_BATCH_SIZE
  --device DEVICE
  --eval-dir EVAL_DIR
  --precomputed-dir PRECOMPUTED_DIR
  --metrics METRICS [METRICS ...]
  --seed SEED
  --folder-name FOLDER_NAME
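
Several of the options listed above are shared across the training scripts: --beta-schedule, --beta-start, --beta-end and --timesteps define the forward-noising variances, while --model-mean-type selects whether the network predicts the posterior mean, x_0, or the noise eps. The snippet below is a rough sketch of the standard definitions of two of those schedules and of the eps-to-x_0 conversion; it is illustrative and not an excerpt from this repository's source.

    import torch

    def make_betas(schedule, beta_start, beta_end, timesteps):
        # "linear" and "quad" as defined in the original DDPM codebase;
        # the warmup10/warmup50/jsd variants are omitted for brevity.
        if schedule == "linear":
            return torch.linspace(beta_start, beta_end, timesteps)
        if schedule == "quad":
            return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
        raise ValueError(f"unknown schedule: {schedule}")

    def predict_x0_from_eps(x_t, t, eps, alphas_bar):
        # When the model predicts eps, x_0 is recovered via
        # x_0 = (x_t - sqrt(1 - a_bar_t) * eps) / sqrt(a_bar_t).
        a_bar = alphas_bar[t].view(-1, 1, 1, 1)
        return (x_t - (1.0 - a_bar).sqrt() * eps) / a_bar.sqrt()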
			

Examples

  • Train a 25-Gaussian toy model on a single GPU (device id 0) for a total of 100 epochs

    python train_toy.py --dataset gaussian25 --device cuda:0 --epochs 100
    
  • Train a CIFAR-10 model on a single GPU (device id 0) for a total of 50 epochs

    python train.py --dataset cifar10 --train-device cuda:0 --epochs 50
    

(You can always pass --dry-run for testing/tuning purposes.)

  • Train a CelebA model with an effective batch size of 128 (accumulating gradients over 2 mini-batches across 4 GPUs) on a four-card machine (single node) using shared file-system initialization

    python train.py --dataset celeba --num-accum 2 --num-gpus 4 --distributed --rigid-launch
    
    • num-accum 2: accumulate gradients over 2 mini-batches before each update
    • num-gpus: number of GPU(s) to use for training, i.e. WORLD_SIZE of the process group
    • distributed: enable multi-GPU DDP training
    • rigid-launch: use shared file-system initialization and torch.multiprocessing.spawn (both launch styles are sketched after these examples)
  • (Recommended) Train a CelebA model with an effective batch size of 64 x 1 x 2 = 128 using only two GPUs via torchrun (Elastic Launch)[^6] with TCP initialization

    export CUDA_VISIBLE_DEVICES=0,1 && torchrun --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --distributed
    
  • Generate 50,000 samples (128 per mini-batch) from the checkpoint located at ./chkpts/cifar10/cifar10_2040.pt in parallel on 4 GPUs using the DDIM sampler. The results are stored in ./images/eval/cifar10/cifar10_2040_ddim

    python generate.py --dataset cifar10 --chkpt-path ./chkpts/cifar10/cifar10_2040.pt --use-ddim --skip-schedule quadratic --subseq-size 100 --suffix _ddim --num-gpus 4
    
    • use-ddim: sample with DDIM instead of the original DDPM sampler
    • skip-schedule quadratic: use the quadratic timestep-skipping schedule
    • subseq-size: length of the sub-sequence, i.e. the number of DDIM timesteps (see the sketch after these examples)
    • suffix: suffix appended to the dataset name in the output folder name
    • num-gpus: number of GPU(s) to use for generation
  • Evaluate FID and Precision/Recall of the generated samples in ./images/eval/cifar10/cifar10_2040 (the FID formula is sketched below)

    python eval.py --dataset cifar10 --sample-folder ./images/eval/cifar10/cifar10_2040
    
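Under the hood, the two launch modes above correspond to the two standard ways of initializing a PyTorch process group. The sketch below is a generic illustration of shared file-system initialization with torch.multiprocessing.spawn (the --rigid-launch path) and of environment-variable initialization under torchrun; the file path and function names are placeholders, not this repository's exact code.

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def worker(rank, world_size):
        # Shared file-system initialization: every process opens the same file
        # to rendezvous (what --rigid-launch + mp.spawn does conceptually).
        dist.init_process_group("nccl", init_method="file:///tmp/ddpm_shared_init",
                                rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)
        # ... build the model, wrap it in DistributedDataParallel, train ...
        dist.destroy_process_group()

    def launch_rigid(world_size):
        mp.spawn(worker, args=(world_size,), nprocs=world_size)

    def init_under_torchrun():
        # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK; env:// picks them up
        # after the TCP (c10d) rendezvous.
        dist.init_process_group("nccl", init_method="env://")
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

The --skip-schedule and --subseq-size options choose which subset of the T training timesteps the DDIM sampler visits, and --eta interpolates between deterministic DDIM (eta = 0) and DDPM-like stochasticity. A minimal sketch of a quadratic sub-sequence and of the DDIM update under the usual eps-prediction convention follows; it is illustrative and assumes a model(x_t, t) signature, not the repository's implementation.

    import torch

    def ddim_subsequence(timesteps, subseq_size, schedule="quadratic"):
        # Quadratic spacing puts more steps near t = 0; linear spaces them evenly.
        if schedule == "quadratic":
            taus = torch.linspace(0, (timesteps - 1) ** 0.5, subseq_size) ** 2
        else:
            taus = torch.linspace(0, timesteps - 1, subseq_size)
        return taus.long()

    @torch.no_grad()
    def ddim_step(model, x_t, t, t_prev, alphas_bar, eta=0.0):
        # One x_t -> x_{t_prev} jump; t and t_prev are integer timesteps.
        a_t, a_prev = alphas_bar[t], alphas_bar[t_prev]
        eps = model(x_t, torch.full((x_t.shape[0],), t, device=x_t.device))
        x0 = (x_t - (1 - a_t).sqrt() * eps) / a_t.sqrt()              # predicted x_0
        sigma = eta * ((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)).sqrt()
        dir_xt = (1 - a_prev - sigma ** 2).sqrt() * eps               # direction back toward x_t
        return a_prev.sqrt() * x0 + dir_xt + sigma * torch.randn_like(x_t)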

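For reference, the FID reported by eval.py compares Gaussian fits (mean and covariance) of Inception features of real and generated images. Below is a small sketch of the distance itself, using scipy as listed in the requirements; the feature-extraction step is omitted and the array names are placeholders.

    import numpy as np
    from scipy import linalg

    def frechet_distance(feats_real, feats_fake):
        # feats_*: arrays of Inception features, shape (N, D).
        mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
        sigma1 = np.cov(feats_real, rowvar=False)
        sigma2 = np.cov(feats_fake, rowvar=False)
        covmean = linalg.sqrtm(sigma1 @ sigma2).real   # matrix square root
        diff = mu1 - mu2
        return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
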
Experiment results

Toy data

True vs. generated samples: 8 Gaussian, 25 Gaussian, Swiss Roll (image thumbnails)

Training process (animated): 8 Gaussian, 25 Gaussian, Swiss Roll (image thumbnails)

Real-world data

Table of evaluated metrics

| Dataset   | FID (↓) | Precision (↑) | Recall (↑) | Training steps | Training loss | Checkpoint |
|-----------|---------|---------------|------------|----------------|---------------|------------|
| CIFAR-10  | 9.162   | 0.691         | 0.473      | 46.8k          | 0.0295        | -          |
|           | 5.778   | 0.697         | 0.516      | 93.6k          | 0.0293        | -          |
|           | 4.083   | 0.705         | 0.539      | 187.2k         | 0.0291        | -          |
|           | 3.31    | 0.722         | 0.551      | 421.2k         | 0.0284        | -          |
|           | 3.188   | 0.739         | 0.544      | 795.6k         | 0.0277        | [Link]     |
| CelebA    | 4.806   | 0.772         | 0.484      | 189.8k         | 0.0155        | -          |
|           | 3.797   | 0.764         | 0.511      | 379.7k         | 0.0152        | -          |
|           | 2.995   | 0.760         | 0.540      | 949.2k         | 0.0148        | [Link]     |
| CelebA-HQ | 19.742  | 0.683         | 0.256      | 56.2k          | 0.0105        | -          |
|           | 11.971  | 0.705         | 0.364      | 224.6k         | 0.0097        | -          |
|           | 8.851   | 0.768         | 0.376      | 393.1k         | 0.0098        | -          |
|           | 8.91    | 0.800         | 0.357      | 561.6k         | 0.0097        | [Link]     |

Generated images: CIFAR-10, CelebA, CelebA-HQ (image thumbnails)

Denoising process (animated): CIFAR-10, CelebA, CelebA-HQ (image thumbnails)

Related repositories

References

[^1]: Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.
[^2]: Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." International Conference on Learning Representations. 2020.
[^3]: Heusel, Martin, et al. "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium." Advances in Neural Information Processing Systems 30 (2017).
[^4]: Kynkäänniemi, Tuomas, et al. "Improved Precision and Recall Metric for Assessing Generative Models." Advances in Neural Information Processing Systems 32 (2019).
[^5]: DistributedDataParallel - PyTorch 1.12 documentation, https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html.
[^6]: Torchrun (Elastic Launch) - PyTorch 1.12 documentation, https://pytorch.org/docs/stable/elastic/run.html.