Training ImageNet / CIFAR models with SOTA strategies and techniques such as ViT, KD, and Rep.
Image Classification SOTA
Image Classification SOTA is an image classification toolbox based on PyTorch.
Updates
May 27, 2022
- Add knowledge distillation methods (KD and DIST).
March 24, 2022
- Support training strategies in DeiT (ViT).
March 11, 2022
- Release training code.
Supported Algorithms
Structural Re-parameterization (Rep)
- DBB (CVPR 2021) [paper] [original repo]
- DyRep (CVPR 2022) [README]
Knowledge Distillation (KD)
- KD
- DIST
Requirements
torch>=1.0.1
torchvision
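The dependencies can be installed with pip; the command below is a minimal sketch, and the exact torch / torchvision builds to use depend on your CUDA setup:
pip install "torch>=1.0.1" torchvision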
Getting Started
Prepare datasets
It is recommended to symlink the dataset root to image_classification_sota/data (example symlink commands are shown after the directory tree below). Then the file structure should look like this:
image_classification_sota
├── lib
├── tools
├── configs
├── data
│ ├── imagenet
│ │ ├── meta
│ │ ├── train
│ │ ├── val
│ ├── cifar
│ │ ├── cifar-10-batches-py
│ │ ├── cifar-100-python
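For example, if the datasets already live elsewhere on disk, the recommended symlinks can be created as follows (the /path/to/... locations are placeholders for your actual dataset paths):
mkdir -p data
ln -s /path/to/imagenet data/imagenet
ln -s /path/to/cifar data/cifar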
Training configurations
- Strategies: The training strategies are configured using YAML files or command-line arguments. Example configurations are in the configs/strategies directory.
Train a model
- Training with a single GPU
python tools/train.py -c ${CONFIG} --model ${MODEL} [optional arguments]
- Training with multiple GPUs
sh tools/dist_train.sh ${GPU_NUM} ${CONFIG} ${MODEL} [optional arguments]
- For Slurm users
sh tools/slurm_train.sh ${PARTITION} ${GPU_NUM} ${CONFIG} ${MODEL} [optional arguments]
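The available optional arguments are defined in tools/train.py; assuming a standard argparse-style interface, they can be listed with:
python tools/train.py --help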
Examples
- Train ResNet-50 on ImageNet
sh tools/dist_train.sh 8 configs/strategies/resnet/resnet.yaml resnet50 --experiment imagenet_res50
- Train MobileNetV2 on ImageNet
sh tools/dist_train.sh 8 configs/strategies/MBV2/mbv2.yaml nas_model --model-config configs/models/MobileNetV2/MobileNetV2.yaml --experiment imagenet_mbv2
- Train VGG-16 on CIFAR-10
sh tools/dist_train.sh 1 configs/strategies/CIFAR/cifar.yaml nas_model --model-config configs/models/VGG/vgg16_cifar10.yaml --experiment cifar10_vgg16