multipath-nn
Experiments exploring dynamic routing in artificial neural networks
Multipath Neural Network Experiments
This repository contains scripts to run the experiments described in the ICML 2017 paper "Deciding How to Decide: Dynamic Routing in Artificial Neural Networks" and to visualize the results. All scripts are intended to be run from the repository's root directory.
Dependencies
- Python 3, NumPy, and TensorFlow are required to train and test networks.
- Matplotlib and Seaborn are required to generate figures.
Library Modules
- `scripts/lib/data.py` defines the `Dataset` class, which provides access to the datasets downloaded by `scripts/prep-data` and implements data augmentation.
- `scripts/lib/layer_types.py` defines network layers that perform transformations and/or assign costs to network states.
- `scripts/lib/net_types.py` defines statically-routed, actor, and critic networks.
- `scripts/lib/desc.py` defines `net_desc`, a function that returns a serializable description of a network's structure and performance statistics, and `render_net_desc`, which returns a human-readable summary of this description.
- `scripts/lib/serdes.py` defines network serialization and deserialization functions.
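The `Dataset` class itself is not reproduced here. As a rough illustration of what reading one of the `.npz` archives and serving augmented minibatches can look like, the following is a minimal sketch under stated assumptions: the class name `SketchDataset`, the archive keys `x_train`/`y_train`, the `(n, height, width, channels)` image layout, and the flip-and-shift augmentation are all hypothetical and may not match `scripts/lib/data.py`.

```python
import numpy as np


class SketchDataset:
    """Hypothetical stand-in for the repository's Dataset class (not its real API)."""

    def __init__(self, path, x_key="x_train", y_key="y_train", seed=0):
        # np.load on an .npz file returns an NpzFile; indexing it reads each array into memory.
        with np.load(path) as archive:
            self.x = archive[x_key]  # assumed shape: (n, height, width, channels)
            self.y = archive[y_key]  # assumed shape: (n,)
        self.rng = np.random.default_rng(seed)

    def augment(self, batch, max_shift=2):
        """Randomly flip each image horizontally and shift it by up to max_shift pixels."""
        out = np.empty_like(batch)
        for i, img in enumerate(batch):
            if self.rng.random() < 0.5:
                img = img[:, ::-1, :]  # horizontal flip along the width axis
            dy, dx = self.rng.integers(-max_shift, max_shift + 1, size=2)
            # np.roll wraps pixels around the border; padding and cropping is a common alternative.
            out[i] = np.roll(img, shift=(dy, dx), axis=(0, 1))
        return out

    def batches(self, batch_size, augment=True):
        """Yield shuffled (x, y) minibatches that cover the data exactly once."""
        order = self.rng.permutation(len(self.x))
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            xb = self.x[idx]
            yield (self.augment(xb) if augment else xb), self.y[idx]
```

For example, with a 10-example archive, `SketchDataset(path).batches(4)` yields batches of 4, 4, and 2 examples. Augmenting on the fly rather than storing augmented copies keeps the archives small and gives each epoch a fresh set of perturbations.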
Experiment-Running Scripts
- `scripts/prep-data` downloads and formats MNIST, CIFAR-2, CIFAR-5, CIFAR-10, and the hybrid MNIST/CIFAR-10 dataset. The datasets are stored as `.npz` archives in the `data/` directory. This script must be run before any of the others.
- `scripts/train-nets` trains and validates a set of networks. `scripts/train-nets --help` prints a list of available experiments, with names of the form `<dataset>-<net-type>[-<modifications>]`. `<dataset>` corresponds to the name of a file in the `data/` directory (after running `scripts/prep-data`). `<net-type>` is "sr", "ac", or "cr", indicating statically-routed, actor, or critic nets, respectively. `<modifications>` indicates how the network architecture or training procedure is modified (see the paper for details). The trained network parameters and performance statistics are stored in the `nets/` directory.
- `scripts/train-adaptive-nets` is analogous to `scripts/train-nets`, except that it trains and validates a single network that can adapt to various costs of computation.
- `scripts/arch_and_hypers.py` is a module that defines the architecture and hyperparameters used in `scripts/train-nets` and `scripts/train-adaptive-nets`.
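The experiment naming scheme `<dataset>-<net-type>[-<modifications>]` can be split mechanically. The sketch below is only an illustration and is not part of the repository: the names `parse_experiment_name` and `ExperimentName` are hypothetical, and it assumes dataset names may themselves contain hyphens (e.g. a hypothetical `cifar-10`), so it treats the first component equal to "sr", "ac", or "cr" as the boundary and returns any modifications as the unparsed remainder.

```python
from typing import NamedTuple

# Net-type codes as listed in this README.
NET_TYPES = {"sr": "statically-routed", "ac": "actor", "cr": "critic"}


class ExperimentName(NamedTuple):
    dataset: str
    net_type: str
    modifications: str  # empty string when no modifications are given


def parse_experiment_name(name: str) -> ExperimentName:
    """Split '<dataset>-<net-type>[-<modifications>]' into its parts (hypothetical helper)."""
    parts = name.split("-")
    # Skip index 0 so the dataset name is never empty; the first net-type code ends it.
    for i in range(1, len(parts)):
        if parts[i] in NET_TYPES:
            return ExperimentName("-".join(parts[:i]), parts[i], "-".join(parts[i + 1:]))
    raise ValueError(f"no net type ('sr', 'ac', or 'cr') found in {name!r}")
```

For instance, `parse_experiment_name("cifar-10-ac")` returns `ExperimentName(dataset='cifar-10', net_type='ac', modifications='')`. Scanning for the net-type code, rather than taking the second hyphen-separated field, avoids misreading hyphenated dataset names.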
Visualization Scripts
- `scripts/make-acc-eff-plots` writes accuracy-efficiency plots to the `figures/` directory, assuming the prerequisite experiments have been run.
- `scripts/make-nlds` writes node-link diagrams to the `figures/` directory, assuming the prerequisite experiments have been run.
- `scripts/make-routing-hists` writes routing histograms to the `figures/` directory, assuming the prerequisite experiments have been run.
- `scripts/make-pres-figs` generates relatively simple figures designed to be displayed in a live presentation.
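The format of the routing statistics stored in `nets/` is not described here. Assuming, hypothetically, that each test example's route is recorded as an integer path index alongside its class label, the data behind a routing histogram reduces to counting (class, path) pairs. The function `routing_histogram` below is an illustrative sketch, not the code in `scripts/make-routing-hists`.

```python
import numpy as np


def routing_histogram(labels, paths, n_classes, n_paths, normalize=True):
    """Count how often examples of each class are routed along each path.

    labels: int array of shape (n,), the class of each example (assumed input format).
    paths:  int array of shape (n,), the index of the path each example took.
    Returns an (n_classes, n_paths) array. With normalize=True each row sums to 1;
    rows for classes with no examples stay all zero.
    """
    labels = np.asarray(labels)
    paths = np.asarray(paths)
    # Map each (class, path) pair to a single bin index, then count with bincount.
    flat = labels * n_paths + paths
    counts = np.bincount(flat, minlength=n_classes * n_paths)
    hist = counts.reshape(n_classes, n_paths).astype(float)
    if normalize:
        row_sums = hist.sum(axis=1, keepdims=True)
        # Divide only where a row has examples, avoiding division by zero.
        hist = np.divide(hist, row_sums, out=np.zeros_like(hist), where=row_sums > 0)
    return hist
```

The resulting array can be drawn with Matplotlib's `imshow` or Seaborn's `heatmap`. Normalizing per class makes classes with different numbers of test examples directly comparable.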