VQVAE-Pytorch
This repo implements VQVAE on MNIST as well as on a colored version of the MNIST images. It also implements a simple LSTM that generates sample numbers using the encoder outputs of the trained VQVAE.
VQVAE Implementation in PyTorch with generation using LSTM
This repository implements VQVAE for MNIST and a colored version of MNIST, and follows up with a simple LSTM for generating numbers.
VQVAE Explanation and Implementation Video
Quickstart
- Create a new conda environment with Python 3.8, then run the commands below:
    git clone https://github.com/explainingai-code/VQVAE-Pytorch.git
    cd VQVAE-Pytorch
    pip install -r requirements.txt
- To run a simple VQVAE with minimal code and understand the basics:
    python run_simple_vqvae.py
- To experiment with the VQVAE and to train or run inference with the LSTM, use the commands below, passing the desired configuration file as the config argument:
  - python -m tools.train_vqvae (trains the VQVAE)
  - python -m tools.infer_vqvae (generates reconstructions and the encoder outputs used for LSTM training)
  - python -m tools.train_lstm (trains the minimal LSTM)
  - python -m tools.generate_images (uses the trained LSTM to generate some numbers)
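The four module commands feed into each other (inference writes the encoder outputs that LSTM training reads, and generation needs the trained LSTM), so they are meant to run in the order listed. A minimal sketch of a driver for that order, assuming it is launched from the repository root. The helper names `pipeline_commands` and `run_pipeline` are hypothetical and not part of the repo, and because the exact spelling of the config argument is not given here, any extra arguments are simply passed through unchanged:

```python
import subprocess
import sys

# Stage modules in dependency order, as listed in the Quickstart.
STAGES = [
    "tools.train_vqvae",      # train the VQVAE
    "tools.infer_vqvae",      # reconstructions + encoder outputs for the LSTM
    "tools.train_lstm",       # train the LSTM on the saved encoder outputs
    "tools.generate_images",  # generate numbers with the trained LSTM
]


def pipeline_commands(extra_args=()):
    """Return one argv list per stage (hypothetical helper, not part of the repo)."""
    return [[sys.executable, "-m", stage, *extra_args] for stage in STAGES]


def run_pipeline(extra_args=()):
    """Run every stage in order and stop at the first one that fails."""
    for argv in pipeline_commands(extra_args):
        # check=True raises CalledProcessError on a non-zero exit code.
        subprocess.run(argv, check=True)
```

Calling `run_pipeline()` with no arguments runs each module exactly as in the list above; whatever argument the scripts expect for the configuration file can be supplied through `extra_args`.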
Configurations
- config/vqvae_mnist.yaml - VQVAE for training on black and white MNIST images
- config/vqvae_colored_mnist.yaml - VQVAE with more embedding vectors for training on colored MNIST images
Data preparation
To set up the dataset, follow the instructions at https://github.com/explainingai-code/Pytorch-VAE#data-preparation
Verify the data directory has the following structure:
VQVAE-Pytorch/data/train/images/{0/1/.../9}
*.png
VQVAE-Pytorch/data/test/images/{0/1/.../9}
*.png
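The layout above can also be checked with a short script. This is a minimal sketch using only the standard library; the function name `count_digit_images` is hypothetical, and it only assumes the `<split>/images/<digit>/*.png` structure shown above:

```python
from pathlib import Path


def count_digit_images(data_root, split):
    """Return {digit: number of .png files} for data_root/<split>/images/<digit>.

    Raises FileNotFoundError if any of the ten digit folders is missing.
    """
    counts = {}
    for digit in range(10):
        digit_dir = Path(data_root) / split / "images" / str(digit)
        if not digit_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {digit_dir}")
        counts[digit] = sum(1 for _ in digit_dir.glob("*.png"))
    return counts
```

From the repository root, `count_digit_images("data", "train")` and `count_digit_images("data", "test")` should both return ten non-zero counts once the dataset is in place.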
Output
Outputs are saved according to the configuration in the YAML files.
For every run, a folder named after the task_name key in the config is created, and an output_train_dir folder is created inside it.
During training of the VQVAE, the following output is saved:
- Best model checkpoints (VQVAE and LSTM) in the task_name directory
During inference, the following output is saved:
- Reconstructions for a sample of the test set in task_name/output_train_dir/reconstruction.png
- Encoder outputs on the train set for LSTM training in task_name/output_train_dir/mnist_encodings.pkl
- LSTM generation output in task_name/output_train_dir/generation_results.png
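To see at a glance which of these artifacts a run has produced so far, the paths can be assembled from the two config values. A minimal sketch, assuming the file names listed above; `expected_outputs` and `missing_outputs` are hypothetical helpers, and the caller supplies the task_name and output_train_dir values from the YAML file in use:

```python
from pathlib import Path


def expected_outputs(task_name, output_train_dir):
    """Map each artifact described above to its expected path (hypothetical helper)."""
    out_dir = Path(task_name) / output_train_dir
    return {
        "checkpoints_dir": Path(task_name),
        "reconstruction": out_dir / "reconstruction.png",
        "encodings": out_dir / "mnist_encodings.pkl",
        "generation": out_dir / "generation_results.png",
    }


def missing_outputs(task_name, output_train_dir):
    """Return the names of artifacts that do not exist on disk yet."""
    paths = expected_outputs(task_name, output_train_dir)
    return [name for name, path in paths.items() if not path.exists()]
```

An empty list from `missing_outputs(...)` means training, inference, and generation have each written their output at least once.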
Sample Output for VQVAE
Running run_simple_vqvae should be very quick (as it is a very simple model) and give you the reconstructions below (input on a black background and reconstruction on a white background)
Running the default config VQVAE for MNIST should give you the reconstructions below for both versions
Sample Generation Output after just 10 epochs
Training the VQVAE and LSTM for longer and with more parameters (codebook size, codebook dimension, channels, LSTM hidden dimension, etc.) should give better results.
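To make the two codebook parameters concrete: in a VQVAE, each encoder output vector is replaced by its nearest vector in a learned codebook, so the codebook size is the number of vectors available to choose from and the codebook dimension is the length of each one. The following is a dependency-free sketch of that lookup only. It is not the repo's PyTorch code, the function name `quantize` is hypothetical, and training details such as the loss terms and gradient handling are left out:

```python
def quantize(vectors, codebook):
    """Replace each vector with its nearest codebook entry (squared L2 distance).

    vectors:  list of encoder outputs, each a list of codebook_dim floats
    codebook: list of codebook_size embedding vectors of the same length
    Returns (indices, quantized_vectors).
    """
    indices = []
    for v in vectors:
        # Squared Euclidean distance from v to every codebook entry.
        distances = [sum((a - b) ** 2 for a, b in zip(v, e)) for e in codebook]
        indices.append(distances.index(min(distances)))
    return indices, [codebook[i] for i in indices]


# Toy example: codebook size 3, codebook dimension 2.
codebook = [[0.0, 0.0], [1.0, 1.0], [0.0, 1.0]]
indices, quantized = quantize([[0.9, 0.8], [0.1, 0.2]], codebook)
# indices == [1, 0]
```

A larger codebook gives the encoder more distinct vectors to snap to, at the cost of more parameters to train.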
Citations
@misc{oord2018neural,
title={Neural Discrete Representation Learning},
author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
year={2018},
eprint={1711.00937},
archivePrefix={arXiv},
primaryClass={cs.LG}
}