neuralsort
Code for "Stochastic Optimization of Sorting Networks via Continuous Relaxations", ICLR 2019.
Stochastic Optimization of Sorting Networks via Continuous Relaxations
This repository provides a reference implementation for learning NeuralSort-based models as described in the paper:
Stochastic Optimization of Sorting Networks via Continuous Relaxations
Aditya Grover, Eric Wang, Aaron Zweig and Stefano Ermon.
International Conference on Learning Representations (ICLR), 2019.
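For orientation, the deterministic NeuralSort operator from the paper replaces the hard permutation matrix that sorts a score vector s in decreasing order with a row-stochastic relaxation: row i is softmax(((n + 1 - 2i) s - A_s 1) / tau), where A_s[j, k] = |s_j - s_k| and tau is the temperature. The NumPy sketch below is a minimal illustration of that formula, not the repository's implementation; the function name neuralsort is hypothetical.

```python
import numpy as np


def neuralsort(s, tau=1.0):
    """Relaxed permutation matrix that sorts the scores `s` in decreasing order.

    Row i (1-based) is softmax(((n + 1 - 2i) * s - A_s @ 1) / tau),
    where A_s[j, k] = |s_j - s_k|. Hypothetical helper, not the repo's API.
    """
    s = np.asarray(s, dtype=float)
    n = s.shape[0]
    # A_s @ 1: row sums of the pairwise absolute-difference matrix
    b = np.abs(s[:, None] - s[None, :]).sum(axis=1)
    # Per-row scaling factors (n + 1 - 2i) for i = 1, ..., n
    scaling = n + 1 - 2 * np.arange(1, n + 1)
    logits = (scaling[:, None] * s[None, :] - b[None, :]) / tau
    # Numerically stable row-wise softmax
    logits -= logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)


s = np.array([3.0, 1.0, 2.0])
P = neuralsort(s, tau=0.01)
# With a small temperature, P @ s is close to the sorted scores [3, 2, 1]
print(P @ s)
```

As tau shrinks, each row approaches a one-hot vector and P @ s approaches the hard sort; larger temperatures (the --tau flags below) give smoother rows and more informative gradients.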
The codebase is implemented in Python 3.7. To install the necessary requirements, run the following command:
pip install -r requirements.txt
The scripts for downloading and loading the MNIST and CIFAR10 datasets are included in the datasets_loader folder. These scripts will be called automatically the first time the script is run.
Learning and inference of differentiable kNN models is handled by the pytorch/
script, which provides the following command-line arguments:
--k INT number of nearest neighbors
--tau FLOAT temperature of sorting operator
--nloglr FLOAT negative log10 of learning rate
--method STRING one of 'deterministic', 'stochastic'
--dataset STRING one of 'mnist', 'fashion-mnist', 'cifar10'
--num_train_queries INT number of queries to evaluate during training
--num_train_neighbors INT number of neighbors to consider during training
--num_samples INT number of samples for stochastic methods
--num_epochs INT number of epochs to train
-resume start a new model, instead of loading an older one
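The differentiable kNN idea can be sketched by applying the relaxed sort to scores that rank candidate neighbors by closeness to a query, then treating the first k rows of the relaxed permutation as soft top-k weights. The sketch below works under stated assumptions: embeddings are plain NumPy arrays, the score is assumed to be the negative squared Euclidean distance, and the loss is assumed to be the negative expected fraction of the k soft nearest neighbors that share the query's label. The names neuralsort and soft_knn_loss are hypothetical, and the pytorch/ script's actual loss may differ; the NeuralSort helper is repeated so the block runs on its own.

```python
import numpy as np


def neuralsort(s, tau=1.0):
    """Relaxed permutation matrix sorting `s` in decreasing order (NeuralSort)."""
    s = np.asarray(s, dtype=float)
    n = s.shape[0]
    b = np.abs(s[:, None] - s[None, :]).sum(axis=1)
    scaling = n + 1 - 2 * np.arange(1, n + 1)
    logits = (scaling[:, None] * s[None, :] - b[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)


def soft_knn_loss(query, neighbors, neighbor_labels, query_label, k, tau):
    """Hypothetical dKNN-style loss for a single query (lower is better)."""
    # Assumed score: negative squared Euclidean distance, so the nearest sorts first
    scores = -((neighbors - query) ** 2).sum(axis=1)
    P = neuralsort(scores, tau)
    # Soft top-k membership of each candidate: sum of the first k rows
    topk_weights = P[:k].sum(axis=0)
    same_label = (neighbor_labels == query_label).astype(float)
    # Negative expected fraction of the k soft nearest neighbors with the query's label
    return -(topk_weights @ same_label) / k


query = np.zeros(2)
neighbors = np.array([[0.1, 0.0], [0.2, 0.0], [5.0, 0.0], [6.0, 0.0]])
labels = np.array([1, 1, 0, 0])
loss_k2 = soft_knn_loss(query, neighbors, labels, query_label=1, k=2, tau=0.1)
loss_k3 = soft_knn_loss(query, neighbors, labels, query_label=1, k=3, tau=0.1)
print(loss_k2, loss_k3)
```

Here k and tau play the same roles as the --k and --tau flags. Because the loss is a smooth function of the scores, gradients can flow back into whatever network produced the embeddings.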
Learning and inference of quantile-regression models is handled by the tf/
script, which provides the following command-line arguments:
--M INT minibatch size
--n INT number of elements to compare at a time
--l INT number of digits in each multi-mnist dataset element
--tau FLOAT temperature (either of sinkhorn or neuralsort relaxation)
--method STRING one of 'vanilla', 'sinkhorn', 'gumbel_sinkhorn', 'deterministic_neuralsort', 'stochastic_neuralsort'
--n_s INT number of samples for stochastic methods
--num_epochs INT number of epochs to train
--lr FLOAT initial learning rate
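For quantile regression, the relaxed sort can select an element by rank: for an odd number n of items, the middle row of the relaxed permutation softly picks the median-ranked element. The sketch below assumes a model that outputs one ranking score and one predicted value per element, and estimates the median by applying the middle row to those values. The helper names and this particular combination are illustrative assumptions, not the tf/ script's model; the NeuralSort helper is repeated so the block is self-contained.

```python
import numpy as np


def neuralsort(s, tau=1.0):
    """Relaxed permutation matrix sorting `s` in decreasing order (NeuralSort)."""
    s = np.asarray(s, dtype=float)
    n = s.shape[0]
    b = np.abs(s[:, None] - s[None, :]).sum(axis=1)
    scaling = n + 1 - 2 * np.arange(1, n + 1)
    logits = (scaling[:, None] * s[None, :] - b[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)


def soft_median(scores, values, tau):
    """Hypothetical relaxed median: middle row of the relaxed sort applied to `values`."""
    scores = np.asarray(scores, dtype=float)
    n = scores.shape[0]
    P = neuralsort(scores, tau)
    # For odd n, 0-based row (n - 1) // 2 selects the median-ranked element
    return P[(n - 1) // 2] @ np.asarray(values, dtype=float)


# Nine illustrative 5-digit numbers (n = 9, l = 5); a perfect model scores each by its value
numbers = np.array([48213, 10557, 99001, 23456, 77777, 31415, 65432, 12345, 54321])
estimate = soft_median(numbers, numbers, tau=1.0)
# Squared-error regression loss against the true median
loss = (estimate - np.median(numbers)) ** 2
print(estimate, loss)
```

With imperfect scores or a larger temperature, the estimate becomes a weighted mix of nearby elements rather than a single one, which is what makes the objective differentiable.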
Learning and inference of sorting models is handled by the tf/
script, which provides the following command-line arguments:
--M INT minibatch size
--n INT number of elements to compare at a time
--l INT number of digits in each multi-mnist dataset element
--tau FLOAT temperature (either of sinkhorn or neuralsort relaxation)
--method STRING one of 'vanilla', 'sinkhorn', 'gumbel_sinkhorn', 'deterministic_neuralsort', 'stochastic_neuralsort'
--n_s INT number of samples for stochastic methods
--num_epochs INT number of epochs to train
--lr FLOAT initial learning rate
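The 'stochastic_neuralsort' option relaxes sampling from a Plackett-Luce distribution over permutations. A standard way to reparameterize such samples is the Gumbel trick: add independent Gumbel(0, 1) noise to the log-scores and sort the perturbed values, here with the relaxed sort instead of a hard one. The sketch below draws n_s such relaxed permutations and compares them with the ground-truth sorting permutation using a row-wise cross-entropy. The function names and the choice of loss are assumptions for illustration; the tf/ script may implement these steps differently. The NeuralSort helper is repeated so the block runs on its own.

```python
import numpy as np


def neuralsort(s, tau=1.0):
    """Relaxed permutation matrix sorting `s` in decreasing order (NeuralSort)."""
    s = np.asarray(s, dtype=float)
    n = s.shape[0]
    b = np.abs(s[:, None] - s[None, :]).sum(axis=1)
    scaling = n + 1 - 2 * np.arange(1, n + 1)
    logits = (scaling[:, None] * s[None, :] - b[None, :]) / tau
    logits -= logits.max(axis=1, keepdims=True)
    P = np.exp(logits)
    return P / P.sum(axis=1, keepdims=True)


def stochastic_neuralsort(log_scores, tau, n_samples, rng):
    """Relaxed permutations for Plackett-Luce samples via Gumbel perturbation."""
    log_scores = np.asarray(log_scores, dtype=float)
    # Gumbel(0, 1) noise: -log(-log(U)), U ~ Uniform; low > 0 avoids log(0)
    u = rng.uniform(low=1e-12, high=1.0, size=(n_samples, log_scores.shape[0]))
    gumbel = -np.log(-np.log(u))
    return np.stack([neuralsort(log_scores + g, tau) for g in gumbel])


def true_permutation(values):
    """Hard permutation matrix whose row i selects the i-th largest value."""
    order = np.argsort(-np.asarray(values))
    return np.eye(len(values))[order]


def sorting_loss(P_hat, P_true, eps=1e-12):
    """Row-wise cross-entropy, averaged over rows and samples (assumed loss)."""
    return -np.mean(np.sum(P_true * np.log(P_hat + eps), axis=-1))


rng = np.random.default_rng(0)
true_values = np.array([3, 1, 2, 0])
# A hypothetical well-trained model: log-scores widely separated in the true order,
# so the bounded Gumbel noise drawn here cannot change the ordering
log_scores = 100.0 * true_values
P_hat = stochastic_neuralsort(log_scores, tau=1.0, n_samples=5, rng=rng)
P_true = true_permutation(true_values)
print(P_hat.shape, sorting_loss(P_hat, P_true))
```

In practice the log-scores come from a network applied to each multi-digit image, and averaging the loss over the n_s samples gives a reparameterized estimate of the expected loss under the Plackett-Luce distribution.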
Training dKNN model to classify CIFAR10 images
cd pytorch
python --k=9 --tau=64 --nloglr=3 --method=deterministic --dataset=cifar10
Training quantile regression model to predict the median of sets of nine 5-digit numbers
cd tf
python --M=100 --n=9 --l=5 --method=deterministic_neuralsort
Training sorting model to sort sets of five 4-digit numbers
cd tf
python --M=100 --n=5 --l=4 --method=deterministic_neuralsort
If you find NeuralSort useful in your research, please consider citing the following paper:
title={Stochastic Optimization of Sorting Networks via Continuous Relaxations},
author={Aditya Grover and Eric Wang and Aaron Zweig and Stefano Ermon},
booktitle={International Conference on Learning Representations},