HashGAN-Pytorch
This is a PyTorch implementation of Unsupervised Deep Generative Adversarial Hashing Network (CVPR'18) for the CIFAR10 and MNIST datasets.
This code was written to be run on Linux.
Python > 3.6
Using conda is recommended: https://docs.anaconda.com/anaconda/install/linux/
To install PyTorch: https://pytorch.org/get-started/locally/
Run this in terminal:
conda install numpy
Run this in terminal:
conda install tqdm
Run this in terminal:
git clone https://github.com/8uos/HashGAN-Pytorch
The simplest way to train is to run this in the terminal:
python HashGAN.py --train
There are some possible additional arguments:
: The minibatch size. Default value is 100.
: The length of the continuous part of the input random variable. Default value is 128.
: The length of the binary part of the input random variable. Default value is 16.
: Initial learning rate. Default value is 9e-04.
: Final learning rate. Default value is 3e-04.
: The number of epochs to train. Default value is 100.
: The id of gpu to use. Default value is 0.
: The name of the dataset to use. Possible values are 'cifar10' and 'mnist'. Default value is 'mnist'.
: The path to the generator model to load. If this is None, the generator is randomly initialized. Default value is None.
: The path to the discriminator model to load. If this is None, the discriminator is randomly initialized. Default value is None.
: The name of the save directory. Everything will be saved in results/save_dir. Default value is 'temp'.
: Interval to print the losses.
: The location of the dataset. If the dataset does not exist in data_dir, it will be downloaded. Default value is './data'.
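The initial and final learning rates above suggest that the rate changes over the course of training. How the repository moves between the two values is not stated here. The sketch below assumes a simple per-epoch linear decay, and `linear_lr` is a hypothetical helper, not a function from this code base.

```python
def linear_lr(epoch, n_epochs=100, lr_init=9e-4, lr_final=3e-4):
    """Linearly interpolate the learning rate between lr_init and lr_final.

    Assumption: linear decay per epoch. The actual schedule used by
    HashGAN.py may differ.
    """
    if n_epochs <= 1:
        return lr_final
    # Clamp the epoch index to the valid range [0, n_epochs - 1].
    clamped = min(max(epoch, 0), n_epochs - 1)
    frac = clamped / (n_epochs - 1)
    return lr_init + frac * (lr_final - lr_init)


# Learning rate for each of the default 100 epochs.
lrs = [linear_lr(e) for e in range(100)]
```

With the default values, the rate starts at 9e-04 in the first epoch and reaches 3e-04 in the last one.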
python HashGAN.py --eval \
--G_dict=Path/to/generator/dict/to/evaluate \
The possible additional arguments are identical to the ones above.
HashGAN (class)
init (method)
Defines the HashGAN network.
loss_D (method)
Computes the discriminator and encoder losses and stores them as attributes of the HashGAN class.
loss_G (method)
Computes the feature matching loss and stores it as an attribute of the HashGAN class.
step_opt (method)
Computes the gradients and steps the optimizer.
generate_code_label (method)
Generates hash codes with the current encoder and collects the labels for all data points in the given dataloader.
eval (method)
Evaluates the HashGAN network with the given query set and database set.
train (method)
Trains the network.
Define the net, train, evaluate
get_trmat (function)
Builds the transform matrix used to compute the consistent bit loss.
set_input (function)
Builds the input random variable, which consists of a continuous part and a binary part.
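As an illustration of such an input, here is a minimal NumPy sketch that concatenates a continuous part and a binary part using the default sizes listed in the arguments (batch of 100, 128 continuous and 16 binary entries). The distributions (uniform on [0, 1) and fair coin flips) are assumptions, `make_input` is a hypothetical name, and the repository's `set_input` presumably works on PyTorch tensors rather than NumPy arrays.

```python
import numpy as np


def make_input(batch_size=100, n_cont=128, n_bin=16, rng=None):
    """Build a (batch_size, n_cont + n_bin) input array.

    Assumptions: the continuous part is uniform on [0, 1) and each
    binary entry is 0 or 1 with equal probability.
    """
    rng = np.random.default_rng() if rng is None else rng
    z_cont = rng.uniform(0.0, 1.0, size=(batch_size, n_cont))
    z_bin = rng.integers(0, 2, size=(batch_size, n_bin)).astype(np.float64)
    # Continuous columns come first, binary columns last.
    return np.concatenate([z_cont, z_bin], axis=1)


z = make_input(rng=np.random.default_rng(0))
```

Passing a seeded generator makes the draw reproducible, which helps when comparing runs.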
get_prec_topn (function)
Computes precision@topn for the given query and database codes.
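One common way to compute precision@topn for binary codes is to rank database items by Hamming distance to each query and count how many of the closest `topn` share the query's label. The NumPy sketch below follows that recipe. The function name, the argument layout, and the tie-breaking by database order are assumptions, not the signature of `get_prec_topn`.

```python
import numpy as np


def precision_at_topn(query_codes, query_labels, db_codes, db_labels, topn):
    """Mean fraction of the topn nearest database items that share the query label."""
    # Hamming distance between every query code and every database code.
    dist = (query_codes[:, None, :] != db_codes[None, :, :]).sum(axis=2)
    # Indices of the topn closest database items per query (stable for ties).
    nearest = np.argsort(dist, axis=1, kind="stable")[:, :topn]
    # A retrieved item is a hit when its label equals the query label.
    hits = db_labels[nearest] == query_labels[:, None]
    return float(hits.mean())


db_codes = np.array([[0, 0], [0, 1], [1, 1]])
db_labels = np.array([0, 0, 1])
query_codes = np.array([[0, 0], [1, 1]])
query_labels = np.array([0, 1])

prec = precision_at_topn(query_codes, query_labels, db_codes, db_labels, topn=2)
```

In this toy example the first query retrieves two items with its own label and the second retrieves one, so the mean precision is 0.75.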
bit_entropy (function)
Computes the entropy of each bit of the given code.
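A per-bit entropy can be sketched as the binary entropy of the fraction of ones in each code position. The version below assumes codes with entries in {0, 1} and uses base-2 logarithms, so a perfectly balanced bit scores 1. The name `bit_entropy_sketch` is hypothetical and the repository's `bit_entropy` may differ in both conventions.

```python
import numpy as np


def bit_entropy_sketch(codes, eps=1e-12):
    """Binary entropy (in bits) of each column of a 0/1 code matrix."""
    # Fraction of ones in each bit position.
    p = codes.mean(axis=0)
    # Clip away from 0 and 1 so the logarithms stay finite.
    p = np.clip(p, eps, 1.0 - eps)
    return -(p * np.log2(p) + (1.0 - p) * np.log2(1.0 - p))


codes = np.array([[0, 1], [1, 1], [0, 1], [1, 1]])
h = bit_entropy_sketch(codes)
```

Here the first bit is split evenly between 0 and 1 and gets an entropy close to 1, while the second bit is constant and gets an entropy close to 0.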
Generator (class)
Defines the generator network model.
Discriminator (class)
Defines the discriminator and encoder network model.
- K. Dizaji, F. Zheng, N. Nourabadi, Y. Yang, C. Deng, H. Huang. "Unsupervised Deep Generative Adversarial Hashing Network", CVPR, 2018.