
[ICML2022] Variational Wasserstein gradient flow

Variational Wasserstein gradient flow

This is the official Python implementation of the paper Variational Wasserstein gradient flow (paper on arXiv) by Jiaojiao Fan, Qinsheng Zhang, Amirhossein Taghvaei and Yongxin Chen.

The repository contains reproducible PyTorch source code for computing Wasserstein gradient flows in high dimensions, using a variational estimate of the target functional.
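The idea of estimating a functional variationally can be shown on a toy problem. The sketch below is not the repository's code. It assumes the target functional is a KL divergence and uses the standard dual (variational) representation of f-divergences, KL(P || Q) = sup_h E_P[h(X)] - E_Q[exp(h(Y) - 1)]. For simplicity it restricts h to an affine family h(x) = a + b*x, which happens to contain the optimizer when P and Q are unit-variance Gaussians. The repository itself trains neural networks in PyTorch. The helper name `variational_kl` is hypothetical.

```python
import math
import random


def variational_kl(p_samples, q_samples, lr=0.2, steps=500):
    """Hypothetical helper: estimate KL(P || Q) from samples via the dual
        KL(P || Q) = sup_h  E_P[h(X)] - E_Q[exp(h(Y) - 1)],
    with h restricted (as an assumption) to h(x) = a + b * x.
    """
    a, b = 0.0, 0.0
    n_q = len(q_samples)
    mean_p = sum(p_samples) / len(p_samples)  # E_P[x] does not change, compute once
    for _ in range(steps):
        # w_i = exp(h(y_i) - 1) for every sample y_i drawn from Q
        w = [math.exp(a + b * y - 1.0) for y in q_samples]
        grad_a = 1.0 - sum(w) / n_q
        grad_b = mean_p - sum(wi * y for wi, y in zip(w, q_samples)) / n_q
        # Gradient ascent: the objective is concave in (a, b)
        a += lr * grad_a
        b += lr * grad_b
    w = [math.exp(a + b * y - 1.0) for y in q_samples]
    estimate = a + b * mean_p - sum(w) / n_q  # E_P[h] - E_Q[exp(h - 1)]
    return estimate, a, b


rng = random.Random(0)
mu = 1.0
p = [rng.gauss(mu, 1.0) for _ in range(4000)]    # samples from P = N(mu, 1)
q = [rng.gauss(0.0, 1.0) for _ in range(4000)]   # samples from Q = N(0, 1)
kl_hat, a_hat, b_hat = variational_kl(p, q)
print(f"estimated KL = {kl_hat:.3f}, exact KL = {mu ** 2 / 2:.3f}")
```

For these two Gaussians the exact value is mu^2 / 2, and the optimal test function is h(x) = 1 + mu*x - mu^2/2, so the fitted slope `b_hat` should land near mu. Only samples from P and Q are used, never their densities, which is what makes this kind of estimate usable in high dimensions once h is replaced by a richer model.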

Repository structure

The codebase is tested on CUDA version 11.4 and PyTorch version 1.10.1+cu113.

To reproduce all experiments except image generation, go to the toy folder and follow the instructions in toy/README.md:

cd toy

To reproduce the image generation experiment, go to the image folder and follow the instructions in image/README.md:

cd image

Citation

@inproceedings{
  fan2022variational,
  title={Variational Wasserstein gradient flow},
  author={Fan, Jiaojiao and Zhang, Qinsheng and Taghvaei, Amirhossein and Chen, Yongxin},
  booktitle={International Conference on Machine Learning},
  year={2022}
}