Wasserstein2Barycenters
PyTorch implementation of the paper "Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization" (ICLR 2021)
Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization
This is the official Python implementation of the ICLR 2021 paper Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization (paper on OpenReview) by Alexander Korotin, Lingxiao Li, Justin Solomon and Evgeny Burnaev.
The repository contains fully reproducible PyTorch source code for computing Wasserstein-2 barycenters in high dimensions with the non-minimax method proposed in the paper, which relies on input convex neural networks. Examples include various toy problems and the averaging of image color palettes.
Citation
@inproceedings{
korotin2021continuous,
title={Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization},
author={Alexander Korotin and Lingxiao Li and Justin Solomon and Evgeny Burnaev},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=3tFAs5E-Pe}
}
Prerequisites
The implementation is GPU-based. A single GPU (~GTX 1080 Ti) is enough to run each experiment. Tested with torch==1.3.0. The code might not run as intended with newer torch versions.
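Because only torch==1.3.0 is reported as tested, a notebook could warn early when a different version is installed. The helper below is a hypothetical sketch (the name check_torch_version is an assumption, not part of the repository); it reads the installed version through the standard importlib.metadata module and only emits a warning, never blocks execution.

```python
import warnings
from importlib.metadata import PackageNotFoundError, version

TESTED_TORCH = "1.3.0"  # version reported as tested in this README


def check_torch_version(tested=TESTED_TORCH):
    """Hypothetical helper: warn if the installed torch differs from the tested one."""
    try:
        installed = version("torch")
    except PackageNotFoundError:
        warnings.warn("torch is not installed")
        return None
    # Drop local build tags such as "+cu100" before comparing.
    if installed.split("+")[0] != tested:
        warnings.warn(f"torch {installed} found; the code was tested with torch=={tested}")
    return installed
```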
Related repositories
- Repository for Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport? paper.
- Repository for Wasserstein-2 Generative Networks paper.
- Repository for Continuous Regularized Wasserstein Barycenters paper.
- Repository for Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark paper.
- Repository for Large-Scale Wasserstein Gradient Flows paper.
Repository structure
The code for running the experiments is located in self-contained Jupyter notebooks (notebooks/). For convenience, most of the evaluation output is preserved. Other auxiliary source code is moved to .py modules (src/).
Experiments
- notebooks/CW2B_toy_experiments.ipynb -- toy experiments (in dimensions up to 256) and subset posterior aggregation.
- notebooks/CW2B_averaging_color_palettes.ipynb -- averaging color palettes of images.
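High-dimensional toy comparisons need a reference barycenter. When every input is Gaussian N(m_n, C_n) with weights w_n, the W2 barycenter is Gaussian with mean Σ_n w_n m_n, and its covariance S solves the fixed-point equation S = Σ_n w_n (S^{1/2} C_n S^{1/2})^{1/2}, which a known iteration can solve. The NumPy sketch below assumes Gaussian inputs; function names are hypothetical, and it is not a claim about how the notebook computes its ground truth.

```python
import numpy as np


def sqrtm_psd(M):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh((M + M.T) / 2.0)
    vals = np.clip(vals, 0.0, None)  # remove tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_barycenter(means, covs, weights, iters=100, tol=1e-10):
    """W2 barycenter of Gaussians: weighted mean plus fixed-point iteration for the covariance."""
    weights = np.asarray(weights, dtype=float)
    mean = sum(w * m for w, m in zip(weights, means))
    S = sum(w * C for w, C in zip(weights, covs))  # initial guess: weighted covariance
    for _ in range(iters):
        root = sqrtm_psd(S)
        inv_root = np.linalg.inv(root)
        T = sum(w * sqrtm_psd(root @ C @ root) for w, C in zip(weights, covs))
        # Update S <- S^{-1/2} T^2 S^{-1/2}; its fixed point satisfies S = T.
        S_new = inv_root @ T @ T @ inv_root
        converged = np.linalg.norm(S_new - S) < tol
        S = S_new
        if converged:
            break
    return mean, S
```

For commuting (e.g. diagonal) covariances the result reduces to (Σ_n w_n C_n^{1/2})^2, which makes a convenient sanity check.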
Input convex neural networks
- src/icnn.py -- modules for Input Convex Neural Network architectures (DenseICNN).
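An input convex neural network stays convex in its input because hidden-to-hidden weights are non-negative and the activations are convex and non-decreasing, while the direct input-to-layer weights are unconstrained. The NumPy sketch below is a hypothetical, forward-pass-only illustration of that construction (the class name TinyICNN, layer sizes, and the softplus activation are assumptions); it is not the repository's DenseICNN and contains no training code.

```python
import numpy as np


def softplus(x):
    """Numerically stable log(1 + exp(x)): convex and non-decreasing."""
    return np.logaddexp(0.0, x)


class TinyICNN:
    """Hypothetical minimal ICNN with scalar output f(x), convex in x."""

    def __init__(self, dim, hidden=(16, 16), seed=0):
        rng = np.random.default_rng(seed)
        # Input-to-layer weights: unconstrained (affine in x keeps convexity).
        self.A = [rng.normal(size=(h, dim)) for h in hidden]
        self.b = [rng.normal(size=h) for h in hidden]
        # Hidden-to-hidden and output weights: non-negative, so that
        # non-negative sums of convex functions remain convex.
        self.W = [np.abs(rng.normal(size=(hidden[k + 1], hidden[k])))
                  for k in range(len(hidden) - 1)]
        self.w_out = np.abs(rng.normal(size=hidden[-1]))

    def __call__(self, x):
        z = softplus(self.A[0] @ x + self.b[0])
        for k, W in enumerate(self.W):
            # Convex, non-decreasing activation of (non-negative mix of convex + affine).
            z = softplus(W @ z + self.A[k + 1] @ x + self.b[k + 1])
        return float(self.w_out @ z)
```

The gradient of such a convex potential with respect to x is the kind of map used to transport one distribution onto another.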
Poster
- poster/CW2B_poster.png -- poster (landscape format).
- poster/CW2B_poster.svg -- source file for the poster.
Visualized Results
The provided code generates the following visual results, which are included in the paper.
Toy Experiments (2D)
The example below shows 4 input distributions (left), the ground-truth barycenter (middle), and the barycenter computed from each of the 4 potentials recovered by our algorithm (right).
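In one dimension the W2 barycenter has a closed form: its quantile function is the weighted average of the input quantile functions, and the optimal map from each input to it is monotone (the derivative of a convex potential). The pure-Python sketch below assumes 1D empirical inputs with equal sample counts; the function names are hypothetical. It mirrors the figure layout: pushing each input through its own map lands on the same barycenter, which is what the per-potential estimates should ideally agree on. The repository's method targets higher dimensions with learned ICNN potentials instead.

```python
def barycenter_1d(samples_list, weights):
    """W2 barycenter of 1D empirical distributions with equal sample counts."""
    n = len(samples_list[0])
    assert all(len(s) == n for s in samples_list)
    assert abs(sum(weights) - 1.0) < 1e-12
    sorted_lists = [sorted(s) for s in samples_list]
    # Quantile averaging: the j-th order statistic of the barycenter is the
    # weighted average of the j-th order statistics of the inputs.
    return [sum(w * s[j] for w, s in zip(weights, sorted_lists)) for j in range(n)]


def push_to_barycenter(samples, bary):
    """Monotone map: the sample with rank j is sent to bary[j]."""
    order = sorted(range(len(samples)), key=lambda i: samples[i])
    out = [0.0] * len(samples)
    for rank, i in enumerate(order):
        out[i] = bary[rank]
    return out


inputs = [
    [0.0, 1.0, 2.0, 3.0],
    [10.0, 12.0, 11.0, 13.0],
    [-4.0, -2.0, -3.0, -1.0],
    [5.0, 5.0, 6.0, 6.0],
]
bary = barycenter_1d(inputs, [0.25] * 4)
print(bary)  # [2.75, 3.5, 4.5, 5.25]
```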
Color Palette Averaging (3D)
The example below shows barycenters of the RGB (3D) color palettes of three images.
Original images and color palettes
"Averaged" images and color palettes (estimated by each of three potentials computed by our algorithm)