
Generative Modeling with Bayesian Sample Inference

Bayesian Sample Inference

Marten Lienen, Marcel Kollovieh, Stephan Günnemann

https://arxiv.org/abs/2502.07580

Getting Started

We provide an educational implementation for interactive exploration of the model in getting-started.ipynb. The notebook is self-contained, so you can download the file and run it directly on your own computer or open it in Google Colab.

To use BSI with your own model architecture, we recommend copying the self-contained bsi.py module into your project. The following snippet shows how to use the module with your own model and training code.

import torch
from torch import nn
from bsi import BSI, Discretization

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=3, padding=1)
        
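    def forward(self, mu, t):
        # Broadcast the per-sample time t of shape (batch,) to (batch, 1, H, W)
        t = torch.movedim(t.expand((1, *mu.shape[-2:], len(t))), -1, 0)
        # Append t as an extra input channel: 3 image channels + 1 time channel
        return self.layer(torch.cat((mu, t), dim=-3))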

# Use your own model here! Check out our DiT and UNet implementations as a
# starting point.
model = Model()
bsi = BSI(
    model, data_shape=(3, 32, 32), lambda_0=1e-2, alpha_M=1e6, alpha_R=2e6,
    k=128, preconditioning="edm", discretization=Discretization.image_8bit())

from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from torch.utils.data import DataLoader
transforms = v2.Compose(
    [v2.ToImage(), v2.ToDtype(dtype=torch.float32, scale=True), v2.Normalize(mean=[0.5], std=[0.5])])
data = CIFAR10("data/cifar10", download=True, transform=transforms)

x, _ = next(iter(DataLoader(data, batch_size=32)))
loss = bsi.train_loss(x)
print(f"Training loss: {loss.mean():.5f}")

elbo, bpd, l_recon, l_measure = bsi.elbo(x, n_recon_samples=1, n_measure_samples=10)
print(f"Bits per dimension: {bpd.mean():.5f}")

from torchvision.utils import make_grid
import matplotlib.pyplot as plt
samples = bsi.sample(n_samples=4**2)
img_grid = make_grid(samples, nrow=4, normalize=True, value_range=(-1, 1))
plt.imshow(torch.movedim(img_grid, 0, -1))
plt.show()
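
The snippet above evaluates the training loss once but never updates the model. A minimal sketch of a training loop could look like the following. It assumes that `bsi.train_loss` returns a differentiable per-sample loss (as the `.mean()` call above suggests) and that the optimizer was built over the parameters of your model. `train_steps` is a hypothetical helper for illustration, not part of bsi.py, and the optimizer and learning rate in the usage comment are placeholders rather than the settings from the paper.

```python
from itertools import islice


def train_steps(bsi, loader, optimizer, n_steps):
    """Run up to n_steps optimizer updates and return the mean loss of each step.

    Hypothetical helper: assumes bsi.train_loss(x) returns a differentiable
    per-sample loss and that loader yields (image_batch, label) pairs.
    """
    losses = []
    for x, _ in islice(loader, n_steps):
        optimizer.zero_grad()
        loss = bsi.train_loss(x).mean()  # average the per-sample losses
        loss.backward()
        optimizer.step()
        losses.append(float(loss))
    return losses


# Usage with the objects defined above (placeholder hyperparameters):
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   loader = DataLoader(data, batch_size=32, shuffle=True)
#   losses = train_steps(bsi, loader, optimizer, n_steps=100)
```

For full training runs with logging, checkpointing, and multi-GPU support, use train.py as described below.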

Installation

If you want to run our code, start by setting up the Python environment. We use pixi to set up reproducible environments based on conda packages. Install it with curl -fsSL https://pixi.sh/install.sh | bash and then run

# Clone the repository
git clone https://github.com/martenlienen/bsi.git

# Change into the repository
cd bsi

# Install and activate the environment
pixi shell

Training

Start a training run by executing train.py with your settings, for example

./train.py data=cifar10

We use hydra for configuration, so you can override all settings from the command line, e.g. the dataset with data=cifar10 as above. Explore all options in the config directory, e.g. with ./train.py trainer.devices=4 trainer.precision=bfloat16 you can train on 4 GPUs in 16-bit bfloat16 precision.
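
To make the dotted override syntax concrete, here is a conceptual sketch of how overrides such as trainer.devices=4 map onto a nested configuration. This is plain Python for illustration, not Hydra's actual parser (which also handles config groups such as data=cifar10, type conversion, and more). The `apply_overrides` helper and the default values are hypothetical; the real defaults live in the config directory.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of a nested dict with dotted key=value overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = result
        for name in parents:
            # Walk down the nested dict, creating levels that do not exist yet
            node = node.setdefault(name, {})
        # Naive value parsing: non-negative integers become int, the rest stays str
        node[leaf] = int(raw) if raw.isdigit() else raw
    return result


# Hypothetical defaults for illustration only
defaults = {"trainer": {"devices": 1, "precision": "float32"}}
config = apply_overrides(defaults, ["trainer.devices=4", "trainer.precision=bfloat16"])
print(config)  # {'trainer': {'devices': 4, 'precision': 'bfloat16'}}
```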

The cifar10 data module will download the dataset for you, but for imagenet32 and imagenet64 you have to download the 32x32 and 64x64 versions yourself from image-net.org in npz format. Unpack the archives into data/imagenet32/data and data/imagenet64/data respectively and then run ./train.py data=imagenet32 and ./train.py data=imagenet64 to preprocess them into hdf5 files.

For example, to reproduce our training on the CIFAR10 dataset with the settings from the VDM paper, run

./train.py experiment=cifar10-vdm

Use experiment=imagenet32-dit and experiment=imagenet64-dit for our diffusion transformer configurations on ImageNet.

To submit runs to a slurm cluster, use the slurm launcher config, e.g.

./train.py -m hydra/launcher=slurm hydra.launcher.partition=my-gpu-partition data=imagenet32

Fine-Tuning

To resume training from a checkpoint, pass a .ckpt file:

./train.py from_ckpt=path/to/file.ckpt task.n_steps=128 some.other_overrides=true

Citation

If you build upon this work, please cite our paper as follows.

@article{lienen2024bsi,
  title={Generative Modeling with Bayesian Sample Inference},
  author={Lienen, Marten and Kollovieh, Marcel and G{\"u}nnemann, Stephan},
  year={2025},
  eprint={2502.07580},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2502.07580},
}