
Implementation of the models and datasets used in "An Information-theoretic Approach to Distribution Shifts"

An Information-theoretic Approach to Distribution Shifts

Distribution Shift example

This repository contains the code used for the results reported in the paper An Information-theoretic Approach to Distribution Shifts.

The code includes the implementation of:

  • Three variants of the CMNIST dataset:
    • CMNIST
    • d-CMNIST
    • y-CMNIST
  • Discrete Models
    • Mutual information computation and optimization for discrete distributions defined through tensors.
    • Discrete encoders defined through a normalized probability mapping matrix.
  • Neural Network-based Models
    • Variational Information Bottleneck (VIB)
    • Domain Adversarial Neural Networks (DANN)
    • Invariant Risk Minimization (IRM)
    • Variance-based Risk Extrapolation (VREx)
    • Conditional Domain Adversarial Neural Networks (CDANN)

The implementations are based on the PyTorch Lightning framework; Hydra is used for configuration management, while the wandb library handles logging and hyper-parameter sweeps (TensorBoard logging is also available).

Further information on the framework and design paradigms used in this implementation can be found here.

Requirements

Conda Environment

The required packages can be installed using conda:

conda env create -f environment.yml

Once the environment has been created, it can be activated with:

conda activate dl-kit

Device configuration

Set a device name by defining the DEVICE_NAME environment variable.

export DEVICE_NAME=<A_VALID_DEVICE_NAME>

This will enable the corresponding configuration (in config/device). The files laptop.yaml and workstation.yaml are two example deployment configurations specifying:

  • dataset and experiment paths
  • hardware-specific configuration (number of GPUs and CPU cores).

It is possible to create a new customized device configuration by:

  1. Creating a my_device.yaml configuration file containing:
data_root: <path_to_dataset_root>             # Path
experiments_root: <path_to_experiment_root>   # Path
download_files: <enable_dataset_download>     # True/False
num_workers: <number_of_CPU_for_data_loading> # Integer
pin_memory: <enable_pin_memory>               # True/False
gpus: <number_of_GPUs>                        # Integer                             
auto_select_gpus: <enable_auto_GPU_selection> # True/False
  2. Setting the DEVICE_NAME environment variable to my_device:
export DEVICE_NAME=my_device
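
For reference, a hypothetical my_device.yaml for a single-GPU machine could look as follows (all values below are illustrative and should be adapted to the actual setup):

data_root: /data/datasets                     # Path
experiments_root: /data/experiments           # Path
download_files: True                          # True/False
num_workers: 8                                # Integer
pin_memory: True                              # True/False
gpus: 1                                       # Integer
auto_select_gpus: False                       # True/False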

Weights & Biases logging (and Sweeps)

In order to use Weights & Biases logging, run:

wandb init

This operation is optional if the desired logging option is TensorBoard, which can be enabled using the flag logging=tensorboard when running the training script.
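
For example, a training run that logs to TensorBoard instead of Weights & Biases can be launched with:

python train.py logging=tensorboard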

Datasets

Graphical Models

The code contains the implementation of the three variations of the Colored MNIST dataset (CMNIST, d-CMNIST, y-CMNIST) and three corresponding versions used for validation and hyper-parameter search (CMNIST_valid, d-CMNIST_valid, y-CMNIST_valid). In the former versions, models are trained on the train+validation sets and evaluated on the test set. In the validation settings, models are trained on the train set and evaluated on the (disjoint) validation set.

CMNIST_samples

The dataset.ipynb notebook contains a simple usage example for the CMNIST, d-CMNIST and y-CMNIST datasets. The current implementation is based on the PyTorch Dataset class to promote reproducibility and re-usability.
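
As a rough sketch of how such a Dataset can be consumed (the import path, constructor arguments and batch format below are assumptions; refer to dataset.ipynb for the actual usage), the datasets can be wrapped in a standard PyTorch DataLoader:

# Illustrative sketch only: the CMNIST import path and constructor signature
# are assumptions; see dataset.ipynb for the actual usage.
from torch.utils.data import DataLoader
from dsit.data import CMNIST  # hypothetical import path

train_set = CMNIST(root="<path_to_dataset_root>", split="train")  # hypothetical arguments
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)

for x, y, e in train_loader:  # image, label and environment, as assumed here
    pass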

Training

Discrete models and direct criteria optimization

Criteria Discrete trajectories

The discrete models can be trained using the command

python train_discrete.py

which will produce a .csv file (results/discrete.csv by default) containing the train and test cross-entropy values for models optimized following the Information Bottleneck, Independence, Sufficiency and Separation criteria on the CMNIST, d-CMNIST and y-CMNIST datasets for different regularization strengths.
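
The discrete models operate directly on probability tensors (see the implementation list above). As a minimal, self-contained sketch of the underlying idea, and not of the repository's API, the mutual information of a discrete joint distribution stored as a 2-D tensor can be computed as follows:

import torch

def mutual_information(p_xy: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # I(X;Y) in nats for a joint distribution p_xy of shape [|X|, |Y|]
    p_xy = p_xy / p_xy.sum()                 # normalize, in case the tensor is unnormalized
    p_x = p_xy.sum(dim=1, keepdim=True)      # marginal p(x), shape [|X|, 1]
    p_y = p_xy.sum(dim=0, keepdim=True)      # marginal p(y), shape [1, |Y|]
    ratio = p_xy / (p_x * p_y + eps)         # p(x,y) / (p(x) p(y))
    return (p_xy * torch.log(ratio + eps)).sum()

# Example: X and Y perfectly correlated over two values, so I(X;Y) = log 2
p = torch.tensor([[0.5, 0.0], [0.0, 0.5]])
print(mutual_information(p))  # tensor(0.6931)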

As with the neural network models, the hyper-parameters can be changed either by editing the discrete_config/config.yaml file or by specifying the corresponding flags when launching the training script.

The error_decomposition.ipynb notebook contains a detailed explanation of how the test error can be decomposed into test information loss and latent test error. The notebook also includes details on the training procedure for discrete models and on how the results reported in the paper have been obtained.

Error Decomposition

Neural Network Models

MLP results

Each model can be trained with the train.py script using the following command

python train.py <FLAGS>

Where <FLAGS> refers to any configuration flag defined in config and handled by Hydra. If no experiment is specified, the script trains a VIB model on the CMNIST dataset by default. Other models can be trained by specifying the model flag (VIB, DANN, IRM, CDANN, VREx), while datasets can be changed using the data flag (CMNIST, d-CMNIST, y-CMNIST).
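
If in doubt about the available options, Hydra-based scripts typically expose the composed configuration and its override groups through the built-in help (assuming the default Hydra setup is used):

python train.py --help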

For example, to run the CDANN model on y-CMNIST, one can use:

python train.py model=CDANN data=y-CMNIST

Other flags allow changing the optimization parameters, logging, evaluation and regularization schedule. The command

python train.py model=CDANN data=y-CMNIST params.lr=1e-3 params.n_adversarial_steps=20 logging=tensorboard train_for="2 hours"

will train a CDANN model on y-CMNIST with a learning rate of 10^{-3}, using 20 training steps of the discriminator for each generator step, with TensorBoard logging, for a total training time of 2 hours. Note that the train_for flag allows specifying the training duration in iterations, epochs, or even seconds, minutes or hours.

Weights & Biases Sweeps

The hyper-parameters sweeps used to produce the plots in the paper can be found in the sweeps directory.

Running sweeps requires the Weights & Biases initialization described above. To run all the experiments used to produce the plots in Figure 3 (bottom row), one can use:

wandb sweep sweeps/sweep_MLP.yml

which will return a unique <SWEEP_ID>.

Each sweep agent can then be launched using:

wandb agent <SWEEP_ID>

from the project directory.