v-diffusion-torch
PyTorch implementation of V-objective Diffusion Probabilistic Models (VDPM) with classifier-free guidance, and more
Key features
- [x] improved UNet design (conditioning, resampling, etc.) [^1]
- [x] continuous-time training on log-SNR schedule [^2]
- [x] DDIM sampler [^3]
- [x] MSE loss reweighting (constant, SNR, truncated-SNR) [^4]
- [x] velocity prediction [^4]
- [x] classifier-free guidance [^5]
- [x] others:
  - [x] distributed data parallel (multi-GPU training)
  - [x] gradient accumulation
  - [x] FID/Precision/Recall evaluation
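As a rough illustration of the velocity-prediction and DDIM items above, the sketch below works on plain scalars with the standard variance-preserving parameterization from [^4], where alpha^2 = sigmoid(logsnr) and sigma^2 = sigmoid(-logsnr). The function names are hypothetical and are not taken from this codebase; the actual implementation operates on PyTorch tensors and may differ in detail.

```python
import math


def alpha_sigma(logsnr):
    """Variance-preserving coefficients: alpha^2 = sigmoid(logsnr), sigma^2 = sigmoid(-logsnr)."""
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-logsnr)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(logsnr)))
    return alpha, sigma


def diffuse(x0, eps, logsnr):
    """Forward process: z = alpha * x0 + sigma * eps."""
    alpha, sigma = alpha_sigma(logsnr)
    return alpha * x0 + sigma * eps


def v_target(x0, eps, logsnr):
    """Velocity target: v = alpha * eps - sigma * x0."""
    alpha, sigma = alpha_sigma(logsnr)
    return alpha * eps - sigma * x0


def x0_from_v(z, v, logsnr):
    """Recover the clean sample: x0 = alpha * z - sigma * v."""
    alpha, sigma = alpha_sigma(logsnr)
    return alpha * z - sigma * v


def eps_from_v(z, v, logsnr):
    """Recover the noise: eps = sigma * z + alpha * v."""
    alpha, sigma = alpha_sigma(logsnr)
    return sigma * z + alpha * v


def ddim_step(z_t, v_pred, logsnr_t, logsnr_s):
    """Deterministic DDIM update from noise level logsnr_t to logsnr_s."""
    x0_hat = x0_from_v(z_t, v_pred, logsnr_t)
    eps_hat = eps_from_v(z_t, v_pred, logsnr_t)
    alpha_s, sigma_s = alpha_sigma(logsnr_s)
    return alpha_s * x0_hat + sigma_s * eps_hat


if __name__ == "__main__":
    x0, eps = 0.5, -1.2
    z_t = diffuse(x0, eps, logsnr=-1.0)
    v = v_target(x0, eps, logsnr=-1.0)
    # With a perfect velocity prediction, one DDIM step lands exactly
    # on the forward-process sample at the new noise level.
    print(ddim_step(z_t, v, logsnr_t=-1.0, logsnr_s=2.0), diffuse(x0, eps, logsnr=2.0))
```

Because alpha^2 + sigma^2 = 1, both the clean sample and the noise can be recovered from a single predicted velocity, which is what lets the DDIM update reuse the same network output.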
Basic usage
usage: train.py [-h] [--dataset {mnist,cifar10,celeba}] [--root ROOT]
[--epochs EPOCHS] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]
[--weight-decay WEIGHT_DECAY] [--batch-size BATCH_SIZE]
[--num-accum NUM_ACCUM] [--train-timesteps TRAIN_TIMESTEPS]
[--sample-timesteps SAMPLE_TIMESTEPS]
[--logsnr-schedule {linear,sigmoid,cosine,legacy}]
[--logsnr-max LOGSNR_MAX] [--logsnr-min LOGSNR_MIN]
[--model-out-type {x_0,eps,both,v}]
[--model-var-type {fixed_small,fixed_large,fixed_medium}]
[--reweight-type {constant,snr,truncated_snr,alpha2}]
[--loss-type {kl,mse}] [--intp-frac INTP_FRAC] [--use-cfg]
[--w-guide W_GUIDE] [--p-uncond P_UNCOND]
[--num-workers NUM_WORKERS] [--train-device TRAIN_DEVICE]
[--eval-device EVAL_DEVICE] [--image-dir IMAGE_DIR]
[--image-intv IMAGE_INTV] [--num-save-images NUM_SAVE_IMAGES]
[--sample-bsz SAMPLE_BSZ] [--config-dir CONFIG_DIR]
[--ckpt-dir CHKPT_DIR] [--ckpt-name CHKPT_NAME]
[--ckpt-intv CHKPT_INTV] [--seed SEED] [--resume] [--eval]
[--use-ema] [--use-ddim] [--ema-decay EMA_DECAY]
[--distributed]
optional arguments:
-h, --help show this help message and exit
--dataset {mnist,cifar10,celeba}
--root ROOT root directory of datasets
--epochs EPOCHS total number of training epochs
--lr LR learning rate
--beta1 BETA1 beta_1 in Adam
--beta2 BETA2 beta_2 in Adam
--weight-decay WEIGHT_DECAY
decoupled weight_decay factor in Adam
--batch-size BATCH_SIZE
--num-accum NUM_ACCUM
number of batches before weight update, a.k.a.
gradient accumulation
--train-timesteps TRAIN_TIMESTEPS
number of diffusion steps for training (0 indicates
continuous training)
--sample-timesteps SAMPLE_TIMESTEPS
number of diffusion steps for sampling
--logsnr-schedule {linear,sigmoid,cosine,legacy}
--logsnr-max LOGSNR_MAX
--logsnr-min LOGSNR_MIN
--model-out-type {x_0,eps,both,v}
--model-var-type {fixed_small,fixed_large,fixed_medium}
--reweight-type {constant,snr,truncated_snr,alpha2}
--loss-type {kl,mse}
--intp-frac INTP_FRAC
--use-cfg whether to use classifier-free guidance
--w-guide W_GUIDE classifier-free guidance strength
--p-uncond P_UNCOND probability of unconditional training
--num-workers NUM_WORKERS
number of workers for data loading
--train-device TRAIN_DEVICE
--eval-device EVAL_DEVICE
--image-dir IMAGE_DIR
--image-intv IMAGE_INTV
--num-save-images NUM_SAVE_IMAGES
number of images to generate & save
--sample-bsz SAMPLE_BSZ
batch size for sampling
--config-dir CONFIG_DIR
--ckpt-dir CHKPT_DIR
--ckpt-name CHKPT_NAME
--ckpt-intv CHKPT_INTV
frequency of saving a checkpoint
--seed SEED random seed
--resume to resume training from a checkpoint
--eval whether to evaluate fid during training
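To make the `--logsnr-schedule`, `--logsnr-min`/`--logsnr-max` and `--reweight-type` options more concrete, here is a minimal sketch of a truncated cosine log-SNR schedule and of the constant, SNR, and truncated-SNR weights described in [^4]. The bounds of ±20 are illustrative values, not confirmed defaults of `train.py`, and the function names are hypothetical. The weights are written for a squared error on the predicted clean sample, which is an assumption about where they are applied.

```python
import math


def logsnr_schedule_cosine(t, logsnr_min=-20.0, logsnr_max=20.0):
    """Map t in [0, 1] to a log-SNR that falls from logsnr_max (t=0) to logsnr_min (t=1)."""
    b = math.atan(math.exp(-0.5 * logsnr_max))
    a = math.atan(math.exp(-0.5 * logsnr_min)) - b
    return -2.0 * math.log(math.tan(a * t + b))


def loss_weight(logsnr, reweight_type="truncated_snr"):
    """Weight on the squared error in x_0 at a given log-SNR."""
    snr = math.exp(logsnr)
    if reweight_type == "constant":
        return 1.0
    if reweight_type == "snr":
        return snr
    if reweight_type == "truncated_snr":
        return max(snr, 1.0)
    raise ValueError(f"unknown reweight_type: {reweight_type}")


if __name__ == "__main__":
    for t in (0.0, 0.25, 0.5, 0.75, 1.0):
        logsnr = logsnr_schedule_cosine(t)
        print(t, logsnr, loss_weight(logsnr))
```

With continuous-time training (`--train-timesteps 0`), t would be drawn from [0, 1] rather than from a fixed grid, and the schedule converts it to the noise level used for that sample.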
Examples
# train cifar10 with one gpu
python train.py --dataset cifar10 --use-ema --use-ddim --num-save-images 80 --use-cfg --epochs 600 --ckpt-intv 120 --image-intv 10
# train cifar10 with two gpus
python -m torch.distributed.run --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset cifar10 --use-ema --use-ddim --num-save-images 80 --use-cfg --epochs 600 --ckpt-intv 120 --image-intv 10 --distributed
# train celeba with one gpu with effective batch_size 128
python train.py --dataset celeba --use-ema --use-ddim --num-save-images 64 --use-cfg --epochs 240 --ckpt-intv 120 --image-intv 10 --num-accum 8 --sample-bsz 32
# train celeba with two gpus
python -m torch.distributed.run --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --use-ema --use-ddim --num-save-images 64 --use-cfg --epochs 240 --ckpt-intv 120 --image-intv 10 --distributed --num-accum 4 --sample-bsz 32
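The CelebA commands rely on gradient accumulation to reach a larger effective batch size. The arithmetic can be sketched as below. The per-device batch size of 16 is only inferred from the "effective batch_size 128" comment combined with `--num-accum 8`; it is not a documented default. The sketch also assumes that `--batch-size` applies per process in distributed runs, and the helper name is hypothetical.

```python
def effective_batch_size(per_device_batch, num_accum, world_size=1):
    """Samples contributing to one optimizer step."""
    return per_device_batch * num_accum * world_size


if __name__ == "__main__":
    # One GPU, 8 accumulation steps (assumed per-device batch size of 16).
    print(effective_batch_size(16, num_accum=8))
    # Two GPUs, 4 accumulation steps each, under the same assumption.
    print(effective_batch_size(16, num_accum=4, world_size=2))
```

Under these assumptions, halving `--num-accum` when doubling the number of GPUs keeps the effective batch size unchanged.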
Conditional generation
CIFAR-10
guidance strength | FID | IS | images |
---|---|---|---|
w=0 | 2.58 | 9.76 | [sample grid] |
w=0.1 | 3.12 | 10.01 | [sample grid] |
w=1 | 21.35 | 9.92 | [sample grid] |

Each sample grid has one row per class, from top to bottom: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, trucks.
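The guidance strength w reported for these samples can be read through the classifier-free guidance rule of [^5], in which the guided prediction is (1 + w) times the conditional output minus w times the unconditional output, so that w=0 reduces to the plain conditional model. The sketch below illustrates that rule on plain Python lists, together with the label dropout controlled by `--p-uncond` during training. The function names and the `null_label` placeholder are hypothetical, and the codebase may apply the rule to a different model output than the one implied here.

```python
import random


def guided_prediction(pred_cond, pred_uncond, w):
    """Classifier-free guidance: (1 + w) * conditional - w * unconditional."""
    return [(1.0 + w) * c - w * u for c, u in zip(pred_cond, pred_uncond)]


def drop_labels(labels, p_uncond, null_label, rng):
    """Replace each label with null_label with probability p_uncond (training only)."""
    return [null_label if rng.random() < p_uncond else y for y in labels]


if __name__ == "__main__":
    cond, uncond = [1.0, 2.0], [0.0, 1.0]
    for w in (0.0, 0.1, 1.0):
        print(w, guided_prediction(cond, uncond, w))
    # Hypothetical null label of -1 marks the unconditional case.
    print(drop_labels([3, 5, 7, 9], p_uncond=0.1, null_label=-1, rng=random.Random(0)))
```

Since the two coefficients sum to one, the same linear combination can be applied to noise or velocity predictions made at the same noisy input.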
CelebA
guidance strength | images |
---|---|
w=0 | [sample grid] |
w=1 | [sample grid] |
w=3 | [sample grid] |

In each sample grid, the columns are hair colors (Black_Hair, Blond_Hair, Brown_Hair, Gray_Hair) and the rows are tags, from top to bottom: Receding_Hairline, Straight_Hair, Wavy_Hair, Bald, Bangs.
More variants (animated)
guidance strength | images |
---|---|
w=0 | [animated sample grid] |
w=1 | [animated sample grid] |
w=3 | [animated sample grid] |

In each animated sample grid, the columns are hair colors (Black_Hair, Blond_Hair, Brown_Hair, Gray_Hair) and the rows are tags, from top to bottom: Receding_Hairline, Straight_Hair, Wavy_Hair, Bald, Bangs.
Acknowledgement
The development of this codebase is largely based on the official JAX implementation open-sourced by Google Research and my previous PyTorch implementation of DDPM, which are available at [google-research/diffusion_distillation] and [tqch/ddpm-torch] respectively.
References
[^1]: Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.
[^2]: Kingma, Diederik, Tim Salimans, Ben Poole, and Jonathan Ho. "Variational diffusion models." Advances in Neural Information Processing Systems 34 (2021): 21696-21707.
[^3]: Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising diffusion implicit models." arXiv preprint arXiv:2010.02502 (2020).
[^4]: Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." arXiv preprint arXiv:2202.00512 (2022).
[^5]: Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." NeurIPS 2021 Workshop on Deep Generative Models and Downstream Applications (2021). https://openreview.net/forum?id=qw8AKxfYbI.