TEC
Towards Sustainable Self-supervised Learning
The official implementation of the paper *Towards Sustainable Self-supervised Learning*. Target-Enhanced Conditional (TEC) pretraining is a faster and stronger self-supervised pretraining method.
Introduction
Although increasingly expensive to train, most self-supervised learning (SSL) models have repeatedly been trained from scratch and never fully utilized, since only a few SOTA models are adopted for downstream tasks. In this work, we explore a sustainable SSL framework with two major challenges: i) learning a stronger new SSL model based on an existing pretrained SSL model, also called the base model, in a cost-friendly manner; ii) making the training of the new model compatible with various base models.
We propose a Target-Enhanced Conditional (TEC) scheme that introduces two components into existing mask-reconstruction based SSL. First, we propose patch-relation enhanced targets, which enhance the targets given by the base model and encourage the new model to learn semantic-relation knowledge from the base model using incomplete inputs. This input hardening and target enhancing help the new model surpass the base model, since they enforce additional patch-relation modeling to handle incomplete inputs. Second, we introduce a conditional adapter that adaptively adjusts the new model's predictions to align with the targets of different base models.
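The pretraining flags `--topkatt` and `--att_tau` below hint at how the enhanced targets are built from the base model's patch attention. The following is a minimal, hypothetical sketch (not the repository's actual code) of sparsifying a base-model attention map with top-k selection and a temperature; whether `tau` sharpens or smooths, and where it is applied, are assumptions:

```python
import torch
import torch.nn.functional as F

def patch_relation_target(base_attn: torch.Tensor, topk: int = 15, tau: float = 1.8) -> torch.Tensor:
    """Hypothetical patch-relation target: sparsify and renormalize a
    base-model attention map so the new model must match only the
    strongest patch-to-patch relations.

    base_attn: (B, N, N) attention scores from the frozen base model.
    """
    # Keep only the top-k relations per query patch; mask out the rest.
    vals, idx = base_attn.topk(topk, dim=-1)                 # (B, N, k)
    masked = torch.full_like(base_attn, float("-inf"))
    masked.scatter_(-1, idx, vals)
    # Temperature-renormalized distribution over the retained relations
    # (the direction of the temperature here is a guess).
    return F.softmax(masked / tau, dim=-1)
```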
TEC Performance
| Method | Network | Base model | Pretrain data | Epochs | Top-1 acc. (%) | PT weights | Logs |
|---|---|---|---|---|---|---|---|
| TEC | ViT-B | MAE 300ep | ImageNet-1k | 100 | 83.9 | weights | PT FT |
| TEC | ViT-B | MAE | ImageNet-1k | 300 | 84.7 | weights | PT FT |
| TEC | ViT-B | MAE | ImageNet-1k | 800 | 84.8 | weights | PT FT |
| TEC | ViT-B | iBoT | ImageNet-1k | 300 | 84.8 | weights | PT FT |
| TEC | ViT-B | iBoT | ImageNet-1k | 800 | 85.1 | weights | PT FT |
| TEC | ViT-L | MAE | ImageNet-1k | 300 | 86.5 | weights | PT FT |
Training
Requirements
- timm==0.3.2, pytorch 1.8.1
- Download the pretrained SSL model of MAE/iBoT from the official repo, then change the pretrained model path in the `models/pretrained_basemodels.py` file (one plausible layout is sketched below).
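The layout of `models/pretrained_basemodels.py` is not shown here; purely as an illustration, one plausible structure is a name-to-path mapping keyed by the `--basemodel` values used in the commands below:

```python
# Hypothetical layout of models/pretrained_basemodels.py -- adjust the paths
# to wherever you downloaded the official MAE/iBoT checkpoints.
PRETRAINED_BASEMODELS = {
    "mae1kbase":  "/path/to/mae_pretrain_vit_base.pth",
    "mae1klarge": "/path/to/mae_pretrain_vit_large.pth",
    "mae1k300ep": "/path/to/mae_pretrain_vit_base_300ep.pth",
    "ibot1kbase": "/path/to/ibot_vit_base.pth",
}
```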
Pretraining and Finetuning
MAE-ViT-B base model and ViT-B new model.
300 epoch model pretraining:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --mlp_token \
    --pred_att \
    --topkatt 15 \
    --att_tau 1.8 \
    --basemodel mae1kbase \
    --model mae_vit_base_patch16 \
    --last_layers 2 \
    --batch_size 256 \
    --mask_ratio 0.75 \
    --epochs 300 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --accum_iter 2 \
    --data_path /dataset/imagenet-raw \
    --output_dir output_dir
```
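Note that `--blr` is a base learning rate, not the absolute one. Following the convention of the MAE codebase that this repo builds on, the absolute learning rate is scaled by the effective batch size (per-GPU batch size × gradient accumulation × number of GPUs):

```python
# Learning-rate scaling as in the MAE codebase:
# absolute lr = base lr (--blr) * effective batch size / 256.
batch_size, accum_iter, num_gpus = 256, 2, 8
eff_batch_size = batch_size * accum_iter * num_gpus   # 4096
blr = 1.5e-4
lr = blr * eff_batch_size / 256                       # 2.4e-3
print(eff_batch_size, lr)
```

Keep `batch_size * accum_iter * nproc_per_node` constant if you change the number of GPUs, so the effective batch size (and thus the absolute learning rate) is unchanged.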
300 epoch model finetuning:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --mlp_token \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune output_dir/checkpoint-299.pth \
    --epochs 100 \
    --blr 5e-4 --layer_decay 0.65 \
    --warmup_epochs 20 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /dataset/imagenet-raw \
    --output_dir output_dir_finetune
```
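`--layer_decay` applies layer-wise learning-rate decay during finetuning: later transformer blocks get larger learning rates than earlier ones, which preserves low-level pretrained features. A minimal sketch of the scaling rule, following the scheme used in the MAE finetuning code:

```python
def layer_scales(num_layers: int = 12, layer_decay: float = 0.65) -> list:
    """Per-layer lr multipliers: layer i (0 = patch embed, num_layers = head)
    is scaled by layer_decay ** (num_layers - i), so later layers train faster."""
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

# e.g. for ViT-B with --layer_decay 0.65: the patch embedding gets about
# 0.65**12 ~= 0.0057x the base lr, while the last block gets 0.65x.
scales = layer_scales()
print(scales[0], scales[-2])
```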
800 epoch model pretraining:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --mlp_token \
    --pred_att \
    --topkatt 15 \
    --att_tau 1.8 \
    --basemodel mae1kbase \
    --model mae_vit_base_patch16 \
    --last_layers 2 \
    --batch_size 256 \
    --mask_ratio 0.75 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --accum_iter 2 \
    --data_path /dataset/imagenet-raw \
    --output_dir output_dir
```
800 epoch model finetuning:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --mlp_token \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune output_dir/checkpoint-799.pth \
    --epochs 100 \
    --blr 5e-4 --layer_decay 0.55 \
    --warmup_epochs 20 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /dataset/imagenet-raw \
    --output_dir output_dir_finetune
```
iBoT-ViT-B base model and ViT-B new model.
300 epoch pretraining:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --mlp_token \
    --pred_att \
    --topkatt 9 \
    --att_tau 1.0 \
    --basemodel ibot1kbase \
    --model mae_vit_base_patch16 \
    --last_layers 2 \
    --batch_size 256 \
    --mask_ratio 0.75 \
    --epochs 300 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --accum_iter 2 \
    --data_path /dataset/imagenet-raw \
    --output_dir output_dir
```
300 epoch finetuning:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --mlp_token \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune output_dir/checkpoint-299.pth \
    --epochs 100 \
    --blr 5e-4 --layer_decay 0.50 \
    --warmup_epochs 5 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /dataset/imagenet-raw \
    --output_dir output_dir_finetune
```
800 epoch pretraining:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --mlp_token \
    --pred_att \
    --topkatt 9 \
    --att_tau 1.0 \
    --basemodel ibot1kbase \
    --model mae_vit_base_patch16 \
    --last_layers 2 \
    --batch_size 256 \
    --mask_ratio 0.75 \
    --epochs 800 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --accum_iter 2 \
    --data_path /dataset/imagenet-raw \
    --output_dir output_dir
```
800 epoch finetuning:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --mlp_token \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune output_dir/checkpoint-799.pth \
    --epochs 100 \
    --blr 5e-4 --layer_decay 0.55 \
    --warmup_epochs 20 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /dataset/imagenet-raw \
    --output_dir output_dir_finetune
```
MAE-ViT-L base model and ViT-L new model.
pretraining:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --mlp_token \
    --pred_att \
    --topkatt 15 \
    --att_tau 1.4 \
    --basemodel mae1klarge \
    --model mae_vit_large_patch16 \
    --last_layers 2 \
    --batch_size 256 \
    --mask_ratio 0.75 \
    --epochs 300 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --accum_iter 2 \
    --data_path /dataset/imagenet-raw \
    --output_dir output_dir
```
finetuning:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --mlp_token \
    --accum_iter 2 \
    --batch_size 64 \
    --model vit_large_patch16 \
    --finetune output_dir/checkpoint-299.pth \
    --epochs 50 \
    --blr 1e-3 --layer_decay 0.65 \
    --min_lr 1e-5 \
    --warmup_epochs 5 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /dataset/imagenet-raw \
    --output_dir output_dir_finetune
```
MAE-ViT-B-300ep base model and ViT-B new model, 100 epochs.
pretraining:
```bash
python -m torch.distributed.launch --nproc_per_node=8 --use_env main_pretrain.py \
    --mlp_token \
    --pred_att \
    --topkatt 15 \
    --att_tau 1.8 \
    --basemodel mae1k300ep \
    --model mae_vit_base_patch16 \
    --last_layers 2 \
    --batch_size 256 \
    --mask_ratio 0.75 \
    --epochs 100 \
    --warmup_epochs 40 \
    --blr 1.5e-4 --weight_decay 0.05 \
    --accum_iter 2 \
    --data_path /dataset/imagenet-raw \
    --output_dir output_dir
```
finetuning:
```bash
python -m torch.distributed.launch --nproc_per_node=8 main_finetune.py \
    --mlp_token \
    --accum_iter 1 \
    --batch_size 128 \
    --model vit_base_patch16 \
    --finetune output_dir/checkpoint-99.pth \
    --epochs 100 \
    --blr 5e-4 --layer_decay 0.65 \
    --weight_decay 0.05 --drop_path 0.1 --mixup 0.8 --cutmix 1.0 --reprob 0.25 \
    --dist_eval --data_path /dataset/imagenet-raw \
    --output_dir output_dir_finetune
```
Convert checkpoint
To use the pretrained checkpoint for downstream tasks, you need to convert it as follows. (Our provided checkpoints have already been converted.)
```bash
python weight_convert.py \
    --mlp_token \
    --model vit_base_patch16 \
    --resume path_to_pretrained_model \
    --output_dir output_dir_convert \
    --ckptname output_ckpt_name.pth
```
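After conversion, the checkpoint can be loaded into a standard ViT backbone for downstream use. A minimal sketch, assuming the converted file stores the encoder weights under a `model` key (as MAE-style checkpoints typically do); the exact key layout depends on `weight_convert.py`:

```python
import torch
from timm.models.vision_transformer import vit_base_patch16_224  # timm==0.3.2

# Hypothetical loading of the converted checkpoint; inspect missing/unexpected
# keys to verify the layout matches your downstream model.
ckpt = torch.load("output_dir_convert/output_ckpt_name.pth", map_location="cpu")
state_dict = ckpt.get("model", ckpt)
model = vit_base_patch16_224(num_classes=0)   # backbone only, no head
msg = model.load_state_dict(state_dict, strict=False)
print(msg.missing_keys, msg.unexpected_keys)
```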
Citation
```bibtex
@article{gao2022towards,
  title={Towards Sustainable Self-supervised Learning},
  author={Gao, Shanghua and Zhou, Pan and Cheng, Ming-Ming and Yan, Shuicheng},
  journal={arXiv preprint arXiv:2210.11016},
  year={2022}
}
```
Acknowledgement
This codebase is built on top of the MAE codebase. Thanks!