Neural Optimal Transport (NOT)
This is the official PyTorch implementation of the ICLR 2023 spotlight paper Neural Optimal Transport (NOT paper on OpenReview) by Alexander Korotin, Daniil Selikhanovych and Evgeny Burnaev. The repository contains reproducible PyTorch source code for computing optimal transport (OT) maps and plans for strong and weak transport costs in high dimensions with neural networks. Examples are provided for toy problems (1D, 2D) and for the unpaired image-to-image translation task for various pairs of datasets.
Repository structure
The implementation is GPU-based with multi-GPU support. Tested with `torch==1.9.0` and 1-4 Tesla V100 GPUs.
All experiments are provided as self-explanatory Jupyter notebooks (`notebooks/`). For convenience, most of the evaluation output is preserved. Auxiliary source code is moved to `.py` modules (`src/`).
- `notebooks/NOT_toy_1D.ipynb` - toy experiments in 1D (weak costs);
- `notebooks/NOT_toy_2D.ipynb` - toy experiments in 2D (weak costs);
- `notebooks/NOT_training_strong.ipynb` - unpaired image-to-image translation (one-to-one, strong costs);
- `notebooks/NOT_training_weak.ipynb` - unpaired image-to-image translation (one-to-many, weak costs);
- `notebooks/NOT_plots.ipynb` - plotting the translation results (pre-trained models are needed);
- `stats/compute_stats.ipynb` - pre-computing InceptionV3 statistics to speed up test FID computation.
Setup
To run the notebooks, it is recommended to create a virtual environment using either `conda` or `venv`. Once the virtual environment is set up, install the required dependencies:

pip install -r requirements.txt

Finally, make sure to install `torch` and `torchvision`. It is advisable to install these packages according to your system and CUDA version; please refer to the official PyTorch website for detailed installation instructions.
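The steps above can be sketched as follows (the environment name `not-env` is arbitrary, and the final line should be adapted to your CUDA version):

```shell
# Minimal setup sketch; "not-env" is an arbitrary environment name.
python3 -m venv not-env
. not-env/bin/activate
pip install -r requirements.txt
# Install torch/torchvision separately, picking the wheel for your CUDA version.
pip install torch torchvision
```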
Educational Materials
- Seminar and solutions on NOT with strong costs;
- Seminar and solutions on NOT with weak costs;
- Vector *.svg sources of the figures in the paper (use Inkscape to edit);
Citation
@inproceedings{
korotin2023neural,
title={Neural Optimal Transport},
author={Korotin, Alexander and Selikhanovych, Daniil and Burnaev, Evgeny},
booktitle={International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=d8CBRlWNkqH}
}
Application to Unpaired Image-to-Image Translation Task
The unpaired domain translation task can be posed as an OT problem. Our NOT algorithm is applicable here. It searches for a transport map with the minimal transport cost (we use $\ell^{2}$), i.e., it naturally aims to preserve certain image attributes during the translation.
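To make the saddle-point structure of the solver concrete, here is a hedged toy sketch of the adversarial objective for a strong quadratic cost, sup over a potential f of inf over a map T of E[c(x, T(x)) - f(T(x))] + E[f(y)]. The network sizes, learning rates, step counts, and Gaussian toy data below are our own illustrative placeholders, not the repository's settings:

```python
import torch
import torch.nn as nn

# Toy 2D sketch of the NOT-style saddle-point objective for the strong
# quadratic cost c(x, y) = ||x - y||^2 / 2 (all hyperparameters are
# illustrative assumptions, not the paper's configuration).
dim = 2
T = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, dim))  # transport map
f = nn.Sequential(nn.Linear(dim, 64), nn.ReLU(), nn.Linear(64, 1))    # potential
opt_T = torch.optim.Adam(T.parameters(), lr=1e-3)
opt_f = torch.optim.Adam(f.parameters(), lr=1e-3)

def cost(x, y):
    """Strong quadratic transport cost ||x - y||^2 / 2, per sample."""
    return 0.5 * ((x - y) ** 2).sum(dim=1)

for step in range(100):
    # Inner problem: T minimizes the transport cost minus the potential at T(x).
    for _ in range(10):
        x = torch.randn(128, dim)                      # source samples
        loss_T = (cost(x, T(x)) - f(T(x)).squeeze(1)).mean()
        opt_T.zero_grad(); loss_T.backward(); opt_T.step()
    # Outer problem: f maximizes E_y[f(y)] - E_x[f(T(x))].
    x = torch.randn(128, dim)
    y = torch.randn(128, dim) + 3.0                    # target samples
    loss_f = f(T(x).detach()).squeeze(1).mean() - f(y).squeeze(1).mean()
    opt_f.zero_grad(); loss_f.backward(); opt_f.step()
```

With the quadratic cost, the inner problem pulls T(x) close to x while the potential pushes it toward the target distribution, which is what makes the learned map attribute-preserving.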
Compared to popular image-to-image translation models based on GANs or diffusion models, our method provides the following key advantages:
- controllable amount of diversity in generated samples (without any duct tape or heuristics);
- better interpretability of the learned map.
Qualitative examples are shown below for various pairs of datasets (at resolutions $128\times 128$ and $64\times 64$).
One-to-one translation, strong OT
We show unpaired translation with NOT with the strong quadratic cost on outdoor → church, celeba (female) → anime, shoes → handbags, handbags → shoes, male → female, anime → shoes, anime → celeba (female) dataset pairs.
One-to-many translation, weak OT
We show unpaired translation with NOT with the $\gamma$-weak quadratic cost on handbags → shoes, celeba (female) → anime, outdoor → church, anime → shoes, shoes → handbags, anime → celeba (female) dataset pairs.
Controlling the amount of diversity
Our method offers a single parameter $\gamma\in[0,+\infty)$ in the weak quadratic cost to control the amount of diversity.
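To illustrate how $\gamma$ enters, here is a hedged sketch of a batched estimator of the $\gamma$-weak quadratic cost $C(x,\mu) = \mathbb{E}_{y\sim\mu}\|x-y\|^2/2 - \tfrac{\gamma}{2}\mathrm{Var}(\mu)$, where $\mu$ is approximated by $Z$ samples $T(x,z)$ of the stochastic map. The helper name and tensor shapes are our assumptions, not the repository's API; $\gamma=0$ recovers the strong cost:

```python
import torch

def weak_quadratic_cost(x: torch.Tensor, t_xz: torch.Tensor, gamma: float) -> torch.Tensor:
    """Batched estimate of the gamma-weak quadratic cost (schematic sketch).

    x:    (B, D) source inputs;
    t_xz: (B, Z, D) outputs of a stochastic map, Z latent samples per input.
    """
    transport = 0.5 * ((x.unsqueeze(1) - t_xz) ** 2).sum(dim=-1).mean(dim=1)  # (B,)
    variance = t_xz.var(dim=1, unbiased=False).sum(dim=-1)                    # (B,)
    return transport - 0.5 * gamma * variance

x = torch.zeros(4, 2)
t_xz = torch.randn(4, 8, 2)
# Larger gamma rewards per-input diversity, so the cost can only decrease:
assert torch.all(weak_quadratic_cost(x, t_xz, 1.0) <= weak_quadratic_cost(x, t_xz, 0.0))
```

Since the variance term is subtracted, increasing $\gamma$ makes spread-out conditional outputs cheaper, which is exactly the one-knob diversity control described above.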
Datasets
- Aligned anime faces (105GB) should be pre-processed with `datasets/preprocess.ipynb`;
- CelebA faces require `datasets/list_attr_celeba.ipynb`;
- Handbags, shoes, churches, outdoor datasets.

The dataloaders can be created by the `load_dataset` function from `src/tools.py`. The latter four datasets are loaded directly into RAM.
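The load-to-RAM pattern can be sketched as follows; note that `load_to_ram` is a hypothetical helper for illustration only, not the actual `load_dataset` signature from `src/tools.py`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical helper mirroring the load-to-RAM behaviour: keep every sample
# in one in-memory tensor and iterate over it with a standard DataLoader.
def load_to_ram(images: torch.Tensor, batch_size: int = 64) -> DataLoader:
    return DataLoader(TensorDataset(images), batch_size=batch_size, shuffle=True)

loader = load_to_ram(torch.randn(256, 3, 64, 64), batch_size=32)
(batch,) = next(iter(loader))  # TensorDataset yields 1-tuples of tensors
assert batch.shape == (32, 3, 64, 64)
```

Keeping smaller datasets fully in RAM avoids per-batch disk reads, which matters when the same images are revisited many times during adversarial training.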
Presentations
- Long talk (Part 1, 2) by Alexander Korotin at the "AI in Industry" Seminar at CS MSU (November 2023, RU)
- Long talk by Alexander Korotin at the Math in ML Seminar at Skoltech (November 2023, RU)
- Short talk by Alexander Korotin at DataStart 2023 (November 2023, RU)
- Short talk by Alexander Korotin at FallML 2023 (November 2023, EN)
- Long talk (Part 1, 2) by Alexander Korotin at Skoltech ML Summer School 2023 (August 2023, RU)
- Long talk (Part 1, 2) by Alexander Korotin at AIRI Summer School 2023 (July 2023, RU)
- Talk by Alexander Korotin at ICLR 2023 (May 2023, EN)
- Talk by Evgeny Burnaev at the Scientific AI seminar (March 2023, RU)
- Short talk by Evgeny Burnaev at the Fall into ML school (02 November 2022, RU)
- Talk by Alexander Korotin at the Seminar of the "AI in Industry" association (13 October 2022, RU)
- Talk by Alexander Korotin at the AIRI conference on AI 2022 (21 July 2022, RU)
- Talk by Alexander Korotin at the TII seminar (09 Aug 2022, EN)
- Talk by Alexander Korotin at the BayesGroup research seminar (20 May 2022, RU, slides)
- Short talk by Alexander Korotin at the Data Fusion 2022 conference (15 April 2022, RU)
- Talk by Alexander Korotin at the LEStart Seminar at Skoltech (24 February 2022, RU)
Related repositories
- Repository for Kernel Neural Optimal Transport paper (ICLR 2023).
- Repository for Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport? paper (NeurIPS 2022).
- Repository for Wasserstein Iterative Networks for Barycenter Estimation paper (NeurIPS 2022).
- Repository for Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark paper (NeurIPS 2021).
Credits
- Weights & Biases developer tools for machine learning;
- pytorch-fid repo to compute the FID score;
- UNet architecture for the transporter network;
- ResNet architectures for the generator and discriminator;
- Inkscape, the awesome editor for vector graphics.