SubpopBench
[ICML 2023] Change is Hard: A Closer Look at Subpopulation Shift
Overview
SubpopBench is a living PyTorch suite containing benchmark datasets and algorithms for subpopulation shift, as introduced in Change is Hard: A Closer Look at Subpopulation Shift (Yang et al., ICML 2023).
Contents
Currently we support 13 datasets and ~20 algorithms that span different learning strategies. Feel free to send us a PR to add your algorithm / dataset for subpopulation shift.
Available Algorithms
The currently available algorithms are:
- Empirical Risk Minimization (ERM, Vapnik, 1998)
- Invariant Risk Minimization (IRM, Arjovsky et al., 2019)
- Group Distributionally Robust Optimization (GroupDRO, Sagawa et al., 2020)
- Conditional Value-at-Risk Distributionally Robust Optimization (CVaRDRO, Duchi and Namkoong, 2018)
- Mixup (Mixup, Zhang et al., 2018)
- Just Train Twice (JTT, Liu et al., 2021)
- Learning from Failure (LfF, Nam et al., 2020)
- Learning Invariant Predictors with Selective Augmentation (LISA, Yao et al., 2022)
- Deep Feature Reweighting (DFR, Kirichenko et al., 2022)
- Maximum Mean Discrepancy (MMD, Li et al., 2018)
- Deep Correlation Alignment (CORAL, Sun and Saenko, 2016)
- Data Re-Sampling (ReSample, Japkowicz, 2000)
- Cost-Sensitive Re-Weighting (ReWeight, Japkowicz, 2000)
- Square-Root Re-Weighting (SqrtReWeight, Japkowicz, 2000)
- Focal Loss (Focal, Lin et al., 2017)
- Class-Balanced Loss (CBLoss, Cui et al., 2019)
- Label-Distribution-Aware Margin Loss (LDAM, Cao et al., 2019)
- Balanced Softmax (BSoftmax, Ren et al., 2020)
- Classifier Re-Training (CRT, Kang et al., 2020)
Send us a PR to add your algorithm! Our implementations use the hyper-parameter grids described here.
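As a rough illustration of what a contributed algorithm can look like, here is a minimal sketch assuming a DomainBed-style interface (an Algorithm base class providing a network, an optimizer, update(), and predict()); the actual base class lives in subpopbench/learning/algorithms.py and its signature may differ. The reweighting logic is a toy example, not an algorithm from the suite:

import torch
import torch.nn.functional as F

class MyAlgorithm(ERM):  # hypothetical: reuse the repo's ERM setup (network + optimizer)
    """Toy ERM variant: reweight per-example losses by inverse in-batch class frequency."""

    def update(self, x, y, a, step):  # assumed signature; check the repo's base class
        losses = F.cross_entropy(self.predict(x), y, reduction="none")
        counts = torch.bincount(y).clamp(min=1).float()  # per-class counts in this batch
        loss = (losses / counts[y]).mean()               # upweight rare classes
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}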
Available Datasets
The currently available datasets are:
- ColoredMNIST (Arjovsky et al., 2019)
- Waterbirds (Wah et al., 2011)
- CelebA (Liu et al., 2015)
- MetaShift (Liang and Zou, 2022)
- CivilComments (Borkan et al., 2019) from the WILDS benchmark
- MultiNLI (Williams et al., 2017)
- MIMIC-CXR (Johnson et al., 2019)
- CheXpert (Irvin et al., 2019)
- CXRMultisite (Puli et al., 2021)
- MIMICNotes (Johnson et al., 2016)
- NICO++ (Zhang et al., 2022)
- ImageNetBG (Xiao et al., 2020)
- Living17 (Santurkar et al., 2020) from the BREEDS benchmark
Send us a PR to add your dataset! You can follow the dataset format described here.
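For intuition, a dataset contribution typically boils down to the raw examples plus a metadata file mapping each example to a split, a class label y, and an attribute a. The column names and split encoding below are illustrative assumptions, not the repo's exact schema; follow the dataset format docs for the real layout:

import pandas as pd

# Hypothetical metadata for a new dataset; column names and the
# split encoding (0=train, 1=val, 2=test) are assumptions for illustration.
records = [
    ("img_0001.jpg", 0, 0, 0),  # filename, split, class label y, attribute a
    ("img_0002.jpg", 0, 1, 1),
    ("img_0003.jpg", 1, 0, 1),
    ("img_0004.jpg", 2, 1, 0),
]
pd.DataFrame(records, columns=["filename", "split", "y", "a"]) \
    .to_csv("metadata_mydataset.csv", index=False)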
Model Architectures & Pretraining Methods
The supported image architectures are:
- ResNet-50 on ImageNet-1K using supervised pretraining (resnet_sup_in1k)
- ResNet-50 on ImageNet-21K using supervised pretraining (resnet_sup_in21k, Ridnik et al., 2021)
- ResNet-50 on ImageNet-1K using SimCLR (resnet_simclr_in1k, Chen et al., 2020)
- ResNet-50 on ImageNet-1K using Barlow Twins (resnet_barlow_in1k, Zbontar et al., 2021)
- ResNet-50 on ImageNet-1K using DINO (resnet_dino_in1k, Caron et al., 2021)
- ViT-B on ImageNet-1K using supervised pretraining (vit_sup_in1k, Steiner et al., 2021)
- ViT-B on ImageNet-21K using supervised pretraining (vit_sup_in21k, Steiner et al., 2021)
- ViT-B from OpenAI CLIP (vit_clip_oai, Radford et al., 2021)
- ViT-B pretrained using CLIP on LAION-2B (vit_clip_laion, OpenCLIP)
- ViT-B on SWAG using weakly supervised pretraining (vit_sup_swag, Singh et al., 2022)
- ViT-B on ImageNet-1K using DINO (vit_dino_in1k, Caron et al., 2021)
The supported text architectures are:
- BERT-base-uncased (bert-base-uncased, Devlin et al., 2018)
- GPT-2 (gpt2, Radford et al., 2019)
- RoBERTa-base-uncased (xlm-roberta-base, Liu et al., 2019)
- SciBERT (allenai/scibert_scivocab_uncased, Beltagy et al., 2019)
- DistilBERT-uncased (distilbert-base-uncased, Sanh et al., 2019)
Note that text architectures are only compatible with CivilComments.
Subpopulation Shift Scenarios
We characterize four basic types of subpopulation shift using our framework, and categorize each dataset into its most dominant shift type.
- Spurious Correlations (SC): certain attributes $a$ are spuriously correlated with the label $y$ in training but not in testing.
- Attribute Imbalance (AI): certain attributes are sampled with a much smaller probability than others in $p_{\text{train}}$, but not in $p_{\text{test}}$.
- Class Imbalance (CI): certain (minority) classes are underrepresented in $p_{\text{train}}$, but not in $p_{\text{test}}$.
- Attribute Generalization (AG): certain attributes can be totally missing in $p_{\text{train}}$, but present in $p_{\text{test}}$.
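Writing groups as attribute-label pairs $g = (a, y)$, the four conditions above can be summarized informally as:

$$
\begin{aligned}
\text{SC:}\quad & p_{\text{train}}(y \mid a) \neq p_{\text{test}}(y \mid a) \\
\text{AI:}\quad & p_{\text{train}}(a) \ll p_{\text{train}}(a') \text{ for some attributes } a, a', \text{ but not under } p_{\text{test}} \\
\text{CI:}\quad & p_{\text{train}}(y) \ll p_{\text{train}}(y') \text{ for some classes } y, y', \text{ but not under } p_{\text{test}} \\
\text{AG:}\quad & p_{\text{train}}(a) = 0 \text{ yet } p_{\text{test}}(a) > 0 \text{ for some attribute } a
\end{aligned}
$$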
Evaluation Metrics
We include a variety of metrics, aiming for a thorough evaluation from multiple angles:
- Average Accuracy & Worst Accuracy
- Average Precision & Worst Precision
- Average F1-score & Worst F1-score
- Adjusted Accuracy
- Balanced Accuracy
- AUROC & AUPRC
- Expected Calibration Error (ECE)
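For reference, here is a minimal sketch of two of these metrics: worst-group accuracy (the minimum accuracy over attribute-label groups) and ECE. The function names and binning choices are illustrative, not the repo's exact implementation:

import numpy as np

def worst_group_accuracy(y_true, y_pred, groups):
    """Minimum per-group accuracy, where a group is an (attribute, label) pair."""
    return min((y_pred[groups == g] == y_true[groups == g]).mean()
               for g in np.unique(groups))

def expected_calibration_error(confidences, correct, n_bins=10):
    """Weighted average |accuracy - confidence| gap over equal-width confidence bins."""
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            ece += in_bin.mean() * abs(correct[in_bin].mean() - confidences[in_bin].mean())
    return ece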
Model Selection Criteria
We highlight the impact of whether attributes are known in (1) the training set and (2) the validation set,
where the former is specified by --train_attr in train.py,
and the latter by the model selection criterion.
We show a few important selection criteria:
- OracleWorstAcc: Picks the best test-set worst-group accuracy (oracle)
- ValWorstAccAttributeYes: Picks the best val-set worst-group accuracy (attributes known in validation)
- ValWorstAccAttributeNo: Picks the best val-set worst-class accuracy (attributes unknown in validation; group degenerates to class)
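Concretely, each criterion reduces to ranking runs or checkpoints by a different statistic, roughly as below (a sketch with hypothetical record fields, not the repo's code):

# Hypothetical: 'records' is a list of dicts with per-run evaluation stats.
def select_best(records, criterion):
    key = {
        "OracleWorstAcc": "test_worst_group_acc",         # oracle: peeks at the test set
        "ValWorstAccAttributeYes": "val_worst_group_acc",
        "ValWorstAccAttributeNo": "val_worst_class_acc",  # group degenerates to class
    }[criterion]
    return max(records, key=lambda r: r[key])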
Getting Started
Installation
Prerequisites
Run the following commands to clone this repo and create the Conda environment:
git clone [email protected]:YyzHarry/SubpopBench.git
cd SubpopBench/
conda env create -f environment.yml
conda activate subpop_bench
Downloading Data
Download the original datasets and generate corresponding metadata in your data_path:
python -m subpopbench.scripts.download --data_path <data_path> --download
For MIMICNoFinding, CheXpertNoFinding, CXRMultisite, and MIMICNotes, see MedicalData.md for instructions on downloading these datasets manually.
Code Overview
Main Files
- train.py: main training script
- sweep.py: launch a sweep with all selected algorithms (provided in subpopbench/learning/algorithms.py) on all subpopulation shift datasets
- collect_results.py: collect sweep results to automatically generate result tables (as in the paper)
Main Arguments
- train.py:
  - --dataset: name of the chosen subpopulation dataset
  - --algorithm: algorithm to run
  - --train_attr: whether attributes are known during training (yes or no)
  - --data_dir: data path
  - --output_dir: output path
  - --output_folder_name: output folder name (under output_dir) for the current run
  - --hparams_seed: seed for different hyper-parameters
  - --seed: seed for different runs
  - --stage1_folder & --stage1_algo: arguments for two-stage algorithms
  - --image_arch & --text_arch: model architecture and source of initial model weights (text architectures are only compatible with CivilComments)
- sweep.py:
  - --n_hparams: how many hparams to run for each <dataset, algorithm> pair
  - --best_hp & --n_trials: after sweeping hparams, fix the best hparam and run trials with different seeds
Usage
Train a single model (with unknown attributes)
python -m subpopbench.train \
--algorithm <algo> \
--dataset <dset> \
--train_attr no \
--data_dir <data_path> \
--output_dir <output_path> \
--output_folder_name <output_folder_name>
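For instance, a concrete run with ERM on Waterbirds (the paths and folder name below are placeholders):
python -m subpopbench.train \
       --algorithm ERM \
       --dataset Waterbirds \
       --train_attr no \
       --data_dir ./data \
       --output_dir ./output \
       --output_folder_name erm_waterbirds_run0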
Train a model using 2-stage methods, e.g., DFR (with known attributes)
python -m subpopbench.train \
--algorithm DFR \
--dataset <dset> \
--train_attr yes \
--data_dir <data_path> \
--output_dir <output_path> \
--output_folder_name <output_folder_name> \
--stage1_folder <stage1_model_folder> \
--stage1_algo <stage1_algo>
Launch a sweep with different hparams (with unknown attributes)
python -m subpopbench.sweep launch \
--algorithms <...> \
--dataset <...> \
--train_attr no \
--n_hparams <num_of_hparams> \
--n_trials 1
Launch a sweep after fixing hparam with different seeds (with unknown attributes)
python -m subpopbench.sweep launch \
--algorithms <...> \
--dataset <...> \
--train_attr no \
--best_hp \
--input_folder <...> \
--n_trials <num_of_trials>
Collect the results of your sweep
python -m subpopbench.scripts.collect_results --input_dir <...>
Updates
- [07/2023] Check out the Oral talk video (10 mins) for our ICML paper.
- [05/2023] Paper accepted to ICML 2023.
- [02/2023] arXiv version posted. Code is released.
Acknowledgements
This code is partly based on the open-source implementations from DomainBed, spurious_feature_learning, and multi-domain-imbalance.
Citation
If you find this code or idea useful, please cite our work:
@inproceedings{yang2023change,
title={Change is Hard: A Closer Look at Subpopulation Shift},
author={Yang, Yuzhe and Zhang, Haoran and Katabi, Dina and Ghassemi, Marzyeh},
booktitle={International Conference on Machine Learning},
year={2023}
}
Contact
If you have any questions, feel free to contact us through email ([email protected] & [email protected]) or GitHub issues. Enjoy!