Diffusion Transformer
Implementation of the Diffusion Transformer (DiT) model from the paper Scalable Diffusion Models with Transformers:

See here for the official PyTorch implementation.
Dependencies
- Python 3.9
- PyTorch 2.1.1
Training Diffusion Transformer
Use --data_dir=<data_dir> to specify the dataset path.
python train.py --data_dir=./data/
Samples
Sample output from minDiT (39.89M parameters) on CIFAR-10:

Sample output from minDiT on CelebA:

Hyperparameter settings
Adjust hyperparameters in the config.py file.
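For orientation, the snippet below sketches the kind of settings a config file like this typically groups together. The field names and default values here are assumptions for illustration, not the actual contents of config.py.

```python
# Illustrative sketch only: field names and defaults are assumptions,
# not the actual contents of config.py in this repo.
from dataclasses import dataclass


@dataclass
class Config:
    # model
    image_size: int = 32      # CIFAR-10 resolution
    patch_size: int = 2
    embed_dim: int = 384
    depth: int = 12
    num_heads: int = 6
    # training
    batch_size: int = 128
    lr: float = 1e-4
    epochs: int = 200
    data_dir: str = "./data/"
```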
Implementation notes:
- minDiT is designed to train with reasonable throughput on a single GPU (RTX 3080 Ti).
- minDiT largely follows the original DiT architecture.
- DiT block with adaLN-Zero conditioning (see the sketch after this list).
- Diffusion Transformer with Linformer attention to reduce the memory cost of self-attention.
- EDM sampler (see the sampling sketch after this list).
- FID evaluation.
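Here is a minimal PyTorch sketch of a DiT block with adaLN-Zero conditioning: the shift, scale, and gate parameters are regressed from the conditioning vector and zero-initialised so each block starts as the identity. It uses standard multi-head attention rather than the repo's Linformer attention, and the class and argument names are illustrative, not taken from this codebase.

```python
import torch
import torch.nn as nn


class DiTBlock(nn.Module):
    """Minimal DiT block with adaLN-Zero conditioning (illustrative sketch)."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        # adaLN-Zero: regress shift/scale/gate from the conditioning vector,
        # zero-initialised so the residual branches start as identity.
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim))
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, c):
        # x: token sequence (B, N, dim); c: conditioning embedding (B, dim)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=-1)
        h = self.norm1(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        x = x + gate_msa.unsqueeze(1) * self.attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(h)
        return x
```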
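And a minimal sketch of deterministic EDM sampling (the Heun-style sampler from Karras et al., 2022). The `denoiser(x, sigma)` interface and the default settings below are assumptions for illustration; the repo's sampler may differ in details.

```python
import torch


@torch.no_grad()
def edm_sample(denoiser, shape, num_steps=18, sigma_min=0.002,
               sigma_max=80.0, rho=7.0, device="cpu"):
    """Deterministic EDM sampler (illustrative sketch).

    `denoiser(x, sigma)` is assumed to return the denoised estimate D(x; sigma).
    """
    # Noise-level discretisation from the EDM paper.
    i = torch.arange(num_steps, device=device)
    t = (sigma_max ** (1 / rho) + i / (num_steps - 1)
         * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t = torch.cat([t, torch.zeros(1, device=device)])  # final sigma = 0

    x = torch.randn(shape, device=device) * t[0]
    for n in range(num_steps):
        t_cur, t_next = t[n], t[n + 1]
        d_cur = (x - denoiser(x, t_cur)) / t_cur      # dx/dsigma at t_cur
        x_next = x + (t_next - t_cur) * d_cur         # Euler step
        if t_next > 0:                                # 2nd-order (Heun) correction
            d_next = (x_next - denoiser(x_next, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x
```

With a trained model wrapped to return the denoised estimate, something like `edm_sample(denoise_fn, (16, 3, 32, 32))` would produce a batch of CIFAR-10-sized samples.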
Todo
- Add Classifier-Free Diffusion Guidance and a conditional pipeline.
- Add Latent Diffusion and autoencoder training.
- Add a generate.py sampling script.
Licence
MIT