
Implementation of Consistency Models (Song et al., 2023) for few-step image generation in JAX.

Consistency Models

Description

Implementation of Consistency Models, a class of diffusion-adjacent models introduced in Song et al. (2023), in JAX.

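Consistency models achieve state-of-the-art performance in one- and few-step generation. When distilled from pre-trained diffusion models, they outperform existing diffusion distillation techniques; when trained as standalone generative models, they outperform existing one-step, non-adversarial generative models.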

A minified, self-contained implementation of the discrete-time version of the model trained on MNIST is in the notebook mnist-example.ipynb.
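Consistency models enforce the boundary condition f(x, ε) = x through a skip-connection parametrization of the network output. The sketch below follows the parametrization described in Song et al. (2023), with σ_data = 0.5 and ε = 0.002; whether this repo uses exactly these constants is an assumption, and net is a hypothetical stand-in for the MLP-Mixer backbone rather than this repo's API.

```python
import jax.numpy as jnp

# Constants from Song et al. (2023); assumed here, not read from this repo's config.
SIGMA_DATA = 0.5
EPS = 0.002


def skip_scalings(t):
    """Return (c_skip, c_out) for noise level t.

    At t == EPS, c_skip == 1 and c_out == 0, so f(x, EPS) == x exactly
    (the consistency-model boundary condition).
    """
    c_skip = SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)
    c_out = SIGMA_DATA * (t - EPS) / jnp.sqrt(SIGMA_DATA**2 + t**2)
    return c_skip, c_out


def consistency_fn(net, params, x, t):
    """f_theta(x, t) = c_skip(t) * x + c_out(t) * F_theta(x, t).

    net(params, x, t) is a hypothetical backbone (e.g. an MLP-Mixer);
    t has shape (batch,) and x has shape (batch, ...).
    """
    c_skip, c_out = skip_scalings(t)
    # Broadcast the per-example scalars over all non-batch dimensions.
    shape = (-1,) + (1,) * (x.ndim - 1)
    return c_skip.reshape(shape) * x + c_out.reshape(shape) * net(params, x, t)
```

Because c_out vanishes at t = ε, the boundary condition holds for any backbone weights, so training never has to learn it.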

Implementation Notes

  • This repo uses a simple MLP-Mixer as the backbone for the consistency function.
  • I only implement what the paper calls consistency training (CT), where the model is trained from scratch, rather than consistency distillation (CD), where the model is distilled from a pre-trained diffusion model.
  • The continuous-time objective is implemented, but I have not gotten it to work well for consistency training. The paper notes: "For consistency training (CT), we find it important to initialize consistency models from a pre-trained EDM model in order to stabilize training when using continuous-time objectives. We hypothesize that this is caused by the large variance in our continuous-time loss functions", so this may not be surprising.
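The discrete-time consistency training objective can be sketched roughly as follows. This assumes the paper's Karras-style noise schedule (ε = 0.002, T = 80, ρ = 7) and uses a squared L2 distance in place of the LPIPS metric the paper uses for images. The names karras_boundaries, ct_loss, and ema_update are hypothetical, not this repo's API. The paper also grows the number of discretization steps N and adjusts the EMA rate μ over training; both are fixed inputs here.

```python
import jax
import jax.numpy as jnp

EPS, T_MAX, RHO = 0.002, 80.0, 7.0  # values from the paper; assumed here


def karras_boundaries(n):
    """n noise levels t_1 = EPS < ... < t_n = T_MAX (Karras et al. schedule)."""
    i = jnp.arange(n)
    lo, hi = EPS ** (1 / RHO), T_MAX ** (1 / RHO)
    return (lo + i / (n - 1) * (hi - lo)) ** RHO


def ct_loss(f, params, target_params, x, key, n):
    """Discrete-time consistency training loss with a squared L2 distance.

    f(params, x, t) is the consistency function; target_params are the EMA
    weights theta^-, treated as constants via stop_gradient.
    """
    t = karras_boundaries(n)
    key_idx, key_z = jax.random.split(key)
    # Sample one adjacent pair (t_n, t_{n+1}) per example.
    idx = jax.random.randint(key_idx, (x.shape[0],), 0, n - 1)
    t_n, t_np1 = t[idx], t[idx + 1]
    # The same Gaussian noise z is shared by both noise levels.
    z = jax.random.normal(key_z, x.shape)
    shape = (-1,) + (1,) * (x.ndim - 1)
    online = f(params, x + t_np1.reshape(shape) * z, t_np1)
    target = jax.lax.stop_gradient(
        f(target_params, x + t_n.reshape(shape) * z, t_n)
    )
    # Per-example squared L2 distance, averaged over the batch.
    sq = jnp.sum((online - target) ** 2, axis=tuple(range(1, x.ndim)))
    return jnp.mean(sq)


def ema_update(target_params, params, mu):
    """theta^- <- mu * theta^- + (1 - mu) * theta, leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda a, b: mu * a + (1 - mu) * b, target_params, params
    )
```

Sharing z between the two noise levels is what makes the two inputs lie (approximately) on the same probability-flow ODE trajectory, so no pre-trained diffusion model is needed to produce the target.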

Usage

Training, with optional logging through wandb:

python train.py --config ./config/cifar10.py

Samples

Samples from 5-step (left) and 2-step (right) generation on MNIST, trained for 100k steps with a batch size of 512 in consistency-mnist.ipynb.

5- and 2-step samples for MNIST.
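k-step generation follows the multistep consistency sampling procedure from the paper (Algorithm 1): map pure noise at T to a sample in one call, then alternately re-noise to a smaller level τ and denoise again. A minimal sketch, assuming ε = 0.002 and T = 80 as in the paper; multistep_sample is a hypothetical name, and the τ values in the test are placeholders, not the ones used for the figures.

```python
import jax
import jax.numpy as jnp

EPS, T_MAX = 0.002, 80.0  # assumed, as in the paper


def multistep_sample(f, params, key, shape, taus):
    """Multistep consistency sampling.

    taus is a decreasing sequence of intermediate noise levels; passing
    k - 1 levels gives k-step generation (taus=() is one-step generation).
    """
    key, sub = jax.random.split(key)
    t0 = jnp.full((shape[0],), T_MAX)
    # Step 1: map pure noise at T_MAX directly to a sample.
    x = f(params, T_MAX * jax.random.normal(sub, shape), t0)
    for tau in taus:
        key, sub = jax.random.split(key)
        z = jax.random.normal(sub, shape)
        # Re-noise the current sample to level tau, then denoise in one call.
        x_tau = x + jnp.sqrt(tau**2 - EPS**2) * z
        x = f(params, x_tau, jnp.full((shape[0],), tau))
    return x
```

Each extra step costs one more network evaluation, which is the quality-versus-compute trade-off the 5-step and 2-step samples illustrate.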

Samples from 5-step (left) and 2-step (right) generation on CIFAR-10, trained for ~900k steps with a batch size of 512. These don't look... great, likely because of the choice of an MLP-Mixer backbone.

5- and 2-step samples for CIFAR-10.