diffusion-transformer-keras
Implementation of the Latent Diffusion Transformer model in TensorFlow / Keras
Diffusion Transformer
Implementation of the Diffusion Transformer (DiT) model from the paper Scalable Diffusion Models with Transformers (Peebles & Xie, 2022).

See here for the official PyTorch implementation.
Dependencies
- Python 3.8
- TensorFlow 2.12
Training AutoencoderKL
Use --train_file_pattern=<file_pattern> and --test_file_pattern=<file_pattern> to specify the train and test dataset path.
python ae_train.py --train_file_pattern='./train_dataset_path/*.png' --test_file_pattern='./test_dataset_path/*.png'
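The implementation notes below mention that the AutoencoderKL is trained against a PatchGAN discriminator with a hinge loss. As a rough illustration of that objective (a minimal NumPy sketch under the usual hinge-GAN formulation, not the repo's actual code):

```python
import numpy as np

def hinge_d_loss(real_logits, fake_logits):
    # Discriminator: push logits on real patches above +1
    # and logits on reconstructed (fake) patches below -1.
    real_term = np.mean(np.maximum(0.0, 1.0 - real_logits))
    fake_term = np.mean(np.maximum(0.0, 1.0 + fake_logits))
    return real_term + fake_term

def hinge_g_loss(fake_logits):
    # Generator (the autoencoder): raise the discriminator's
    # logits on its reconstructions.
    return -np.mean(fake_logits)
```

With a PatchGAN, the logits are per-patch maps rather than a single scalar, so the means above average over all patches of the image.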
Training Diffusion Transformer
Use --file_pattern=<file_pattern> to specify the dataset path.
python ldt_train.py --file_pattern='./dataset_path/*.png'
Note: training DiT requires a pretrained AutoencoderKL. Use ae_dir and ae_name in the ldt_config.py file to specify the AutoencoderKL path.
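The corresponding entries in ldt_config.py might look like the following; the key names ae_dir and ae_name come from the note above, while the values shown are placeholders for your own paths:

```python
# ldt_config.py (sketch) -- placeholder values, adjust to your setup.
ae_dir = "ae"        # directory where the trained AutoencoderKL was saved
ae_name = "model_1"  # name of the AutoencoderKL checkpoint to load
```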
Sampling
Use --model_dir=<model_dir> and --ldt_name=<ldt_name> to specify the pre-trained model. For example:
python sample.py --model_dir=ldt --ldt_name=model_1 --diffusion_steps=40
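For reference, each of the diffusion_steps above is one deterministic DDIM update (the η = 0 case). The update can be sketched as follows; ddim_step and its argument names are illustrative, not the repo's API:

```python
import numpy as np

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev):
    """One deterministic DDIM update (eta = 0).

    x_t:            current noisy latent
    eps_pred:       model's noise prediction at the current step
    alpha_bar_t:    cumulative schedule product at the current step
    alpha_bar_prev: cumulative product at the previous (less noisy) step
    """
    # Predict the clean latent x0 implied by the noise estimate.
    x0_pred = (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)
    # Re-noise x0_pred down to the previous timestep's noise level.
    return np.sqrt(alpha_bar_prev) * x0_pred + np.sqrt(1.0 - alpha_bar_prev) * eps_pred
```

Running with --diffusion_steps=40 applies 40 such updates over a strided subset of the training timesteps, which is what makes DDIM sampling fast relative to the full schedule.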
Hyperparameter settings
Adjust hyperparameters in the ae_config.py and ldt_config.py files.
Implementation notes:
- LDT is designed to offer reasonable performance on a single GPU (RTX 3080 Ti).
- LDT largely follows the original DiT model.
- DiT Block with adaLN-Zero.
- Diffusion Transformer with Linformer attention.
- Cosine schedule.
- DDIM sampler.
- FID evaluation.
- AutoencoderKL with PatchGAN discriminator and hinge loss.
- This implementation reuses code from the beresandras repo, under the MIT licence.
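To illustrate the cosine schedule bullet above, here is a minimal NumPy sketch of the continuous-time form from Nichol & Dhariwal (2021); the function names are illustrative and this is not the repo's exact code:

```python
import numpy as np

def cosine_alpha_bar(t, s=0.008):
    """Cumulative signal level alpha_bar at normalized time t in [0, 1].

    s is a small offset that keeps the noise level nonzero near t = 0.
    """
    f = np.cos((t + s) / (1.0 + s) * np.pi / 2.0) ** 2
    f0 = np.cos(s / (1.0 + s) * np.pi / 2.0) ** 2
    return f / f0  # normalized so that alpha_bar(0) == 1

def diffusion_rates(t, s=0.008):
    """Signal/noise rates for corrupting latents: x_t = signal * x0 + noise * eps."""
    ab = cosine_alpha_bar(t, s)
    return np.sqrt(ab), np.sqrt(1.0 - ab)
```

The schedule decays smoothly from pure signal at t = 0 to pure noise at t = 1, and the two rates always satisfy signal² + noise² = 1 (variance-preserving diffusion).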
Samples
Curated samples from FFHQ

Licence
MIT