proximity_vi
Proximity Variational Inference
This code accompanies the proximity variational inference paper: https://arxiv.org/abs/1705.08931
If you use this code, please cite us:
@article{altosaar2017proximity,
author={Altosaar, J and Ranganath, R and Blei, DM},
eprint={arXiv:1705.08931},
title={Proximity Variational Inference},
url={https://arxiv.org/abs/1705.08931},
year={2017}
}
The promise: Variational inference (left) is sensitive to initialization. Proximity variational inference (right) can help correct this.
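Roughly, and as a sketch rather than the paper's exact algorithm: PVI keeps the usual variational objective (the ELBO, written L(λ) for variational parameters λ) but penalizes each new iterate for moving a chosen proximity statistic f too far from its value at the previous iterate. The penalty is a distance d scaled by a constant k ≥ 0, and setting k = 0 recovers ordinary variational inference. The paper gives the precise update, including how the previous iterate is tracked in practice.

```latex
\lambda_{t+1} = \arg\max_{\lambda} \Big[ \mathcal{L}(\lambda) - k \, d\big(f(\lambda), f(\lambda_t)\big) \Big]
```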
Data
Get the binarized MNIST dataset from Larochelle & Murray (2011) and write it to /tmp/binarized_mnist.hdf5:
python get_binary_mnist.py
Environment
I recommend Anaconda: on a Mac, brew install --cask anaconda (older Homebrew versions used brew cask install anaconda); otherwise, use the bash installer. To use the same environment:
conda env create -f environment.yml # you may need to edit environment.yml to choose the CPU or GPU version of TensorFlow
source activate proximity_vi # newer conda versions: conda activate proximity_vi
The code assumes you have set the following environment variables. This enables easy switching between local and remote workstations.
export DAT=/tmp
export LOG=/tmp
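A minimal Python sketch of how a script might read these two variables and fail early when one is missing. The helper name require_dirs is hypothetical, not the repository's actual code, and it assumes the data file sits directly under $DAT.

```python
import os
from pathlib import Path


def require_dirs(names=("DAT", "LOG")):
    """Return {name: Path} for each environment variable, failing early if any is unset.

    Hypothetical helper for illustration; the repository may read these
    variables differently.
    """
    missing = [name for name in names if not os.environ.get(name)]
    if missing:
        raise RuntimeError(
            "Set these environment variables first: " + ", ".join(missing)
        )
    return {name: Path(os.environ[name]) for name in names}
```

For example, with DAT=/tmp set, require_dirs()["DAT"] / "binarized_mnist.hdf5" resolves to /tmp/binarized_mnist.hdf5, the file written by get_binary_mnist.py (assuming the code looks for it there).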
Sigmoid belief network experiment
This benchmarks proximity variational inference against deterministic annealing and vanilla variational inference, under both good and bad initialization (Tables 1 and 2 in the paper).
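To make the mechanics concrete, here is a small self-contained Python toy; it is not the repository's implementation. It fits a Gaussian q(z) = N(mu, sigma^2) to a Gaussian target using exact ELBO gradients, and uses the entropy of q as the proximity statistic with a squared distance. The previous iterate is tracked with an exponential moving average. The target (mean 3, standard deviation 2), step size rho, constant k, and decay alpha are arbitrary illustrative choices. With k = 0 this is plain gradient ascent on the ELBO, i.e. vanilla VI in this toy.

```python
import math


def elbo_grads(mu, omega, m, s):
    """Exact ELBO gradients for q(z) = N(mu, exp(omega)^2) and target N(m, s^2).

    Up to a constant, ELBO = -((mu - m)^2 + sigma^2) / (2 s^2) + omega,
    with sigma = exp(omega).
    """
    sigma2 = math.exp(2.0 * omega)
    d_mu = -(mu - m) / s**2
    d_omega = 1.0 - sigma2 / s**2
    return d_mu, d_omega


def fit(mu, omega, m=3.0, s=2.0, k=0.0, rho=0.05, alpha=0.9, steps=2000):
    """Gradient ascent on ELBO - k * (H[q] - H[q_tilde])^2 (toy PVI; k=0 is vanilla).

    H[q] = omega + 0.5 * log(2*pi*e); the constant cancels in the difference,
    so the penalty depends only on omega - omega_tilde. omega_tilde is an
    exponential moving average of past omega values, standing in for the
    previous iterate. Returns final mu, final sigma, and the sigma path.
    """
    omega_tilde = omega
    path = []
    for _ in range(steps):
        d_mu, d_omega = elbo_grads(mu, omega, m, s)
        # Gradient of the proximity penalty -k * (omega - omega_tilde)^2.
        d_omega -= 2.0 * k * (omega - omega_tilde)
        mu += rho * d_mu
        omega += rho * d_omega
        omega_tilde = alpha * omega_tilde + (1.0 - alpha) * omega
        path.append(math.exp(omega))
    return mu, math.exp(omega), path
```

Starting from a poor initialization such as fit(-5.0, math.log(0.01), k=1.0), sigma grows more slowly early on than with k=0.0, while both runs end at the same optimum (mu = 3, sigma = 2). This toy has a single optimum, so it shows only how the constraint shapes the optimization path, not the improvement from bad initialization that the experiments measure.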
Each experiment takes about half a day on a Tesla P100 GPU:
./sigmoid_belief_network_grid.sh
# List final estimates of the ELBO and marginal likelihood
tail -n 1 $LOG/proximity_vi/*/*/*.log
# View training statistics on tensorboard
tensorboard --logdir $LOG/proximity_vi
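If you prefer to collect the final lines programmatically, a minimal Python sketch that mirrors tail -n 1 over the same glob pattern is below. The helper name last_lines is hypothetical; it does not parse the log contents, whose format this README does not describe, and it skips trailing blank lines.

```python
import glob
import os


def last_lines(log_root):
    """Map each log file under log_root/*/*/*.log to its last non-empty line.

    Mirrors `tail -n 1 $LOG/proximity_vi/*/*/*.log`, except that blank
    lines at the end of a file are skipped. Keys are paths relative to log_root.
    """
    results = {}
    pattern = os.path.join(log_root, "*", "*", "*.log")
    for path in sorted(glob.glob(pattern)):
        with open(path) as f:
            lines = [line.rstrip("\n") for line in f if line.strip()]
        if lines:
            results[os.path.relpath(path, log_root)] = lines[-1]
    return results
```

For example, last_lines(os.path.join(os.environ["LOG"], "proximity_vi")) returns one entry per run directory that the grid scripts have written a log to.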
Variational autoencoder experiment
This tests whether the orthogonal proximity statistic makes optimization easier in a variational autoencoder (Table 3 in the paper).
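This README does not define the orthogonal proximity statistic. As an illustration only, assume it is built on a weight matrix W through the product W Wᵀ and compared against the identity with a squared Frobenius distance, so the value is zero exactly when the rows of W are orthonormal. The function name and this particular form are assumptions; see the paper for the statistic and distance actually used.

```python
def orthogonality_gap(W):
    """Squared Frobenius distance between W W^T and the identity (assumed form).

    W is a list of rows (lists of floats). Returns 0.0 exactly when the rows
    of W are orthonormal. Illustrative only, not the repository's code.
    """
    n = len(W)
    gap = 0.0
    for i in range(n):
        for j in range(n):
            # Entry (i, j) of W W^T is the dot product of rows i and j.
            dot = sum(a * b for a, b in zip(W[i], W[j]))
            target = 1.0 if i == j else 0.0
            gap += (dot - target) ** 2
    return gap
```

A quantity like this, scaled by a constant and subtracted from the objective, would nudge W toward orthogonality; whether the paper measures distance to the identity or to the previous iterate's statistic is a detail to check there.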
Each run takes a few minutes on a Tesla P100 GPU:
./deep_latent_gaussian_model_grid.sh
# List final estimates of the ELBO and marginal likelihood
tail -n 1 $LOG/proximity_vi/*/*/*.log
# View training statistics on tensorboard
tensorboard --logdir $LOG/proximity_vi
Support
Please email me with any questions.