[ICML 2024] Let Go of Your Labels with Unsupervised Transfer
Artyom Gadetsky*, Yulun Jiang*, Maria Brbić
Project page | Paper | BibTeX
This repo contains the source code of 🐢 TURTLE, an unsupervised learning algorithm written in PyTorch. 🔥 TURTLE achieves state-of-the-art unsupervised performance on a variety of benchmark datasets. For more details, please check our paper Let Go of Your Labels with Unsupervised Transfer (ICML '24).
Dependencies
The code is built with the following libraries:
- PyTorch - 2.2.1
- torchvision - 0.17.1
- numpy
- scipy
- scikit-learn
- clip
- tqdm
- cuml - 24.02
To install cuml, you can follow the instructions on this page.
Quick Start
In our paper, we consider 26 vision datasets studied in (Radford et al. 2021) and 9 different foundation models. As a running example, we present the full pipeline to train TURTLE on the CIFAR100 dataset.
- Precompute representations and save ground truth labels for the dataset
python precompute_representations.py --dataset cifar100 --phis clipvitL14
python precompute_representations.py --dataset cifar100 --phis dinov2
python precompute_labels.py --dataset cifar100
- Train TURTLE with 2 representation spaces
python run_turtle.py --dataset cifar100 --phis clipvitL14 dinov2
or with a single representation space:
python run_turtle.py --dataset cifar100 --phis clipvitL14
python run_turtle.py --dataset cifar100 --phis dinov2
The results and checkpoints will be saved to ./data/results and ./data/task_checkpoints, respectively. You can also pass --root_dir to any script to use a root directory other than the default ./data.
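The Quick Start steps above can also be scripted. The following is a minimal sketch, not part of the repository: the helper name `quick_start_commands` is hypothetical, and it assumes that `--root_dir` takes a single path argument, as the text above suggests. It only builds the argument lists and executes nothing by itself.

```python
from typing import List, Optional


def quick_start_commands(dataset: str, phis: List[str],
                         root_dir: Optional[str] = None) -> List[List[str]]:
    """Build the Quick Start commands for one dataset (hypothetical helper)."""
    if not phis:
        raise ValueError("at least one representation is required")
    # Assumption: --root_dir takes one path argument in every script.
    extra = ["--root_dir", root_dir] if root_dir else []
    commands = []
    # Step 1: precompute each representation separately, as in the commands above.
    for phi in phis:
        commands.append(["python", "precompute_representations.py",
                         "--dataset", dataset, "--phis", phi] + extra)
    # Step 1 (continued): save the ground truth labels once per dataset.
    commands.append(["python", "precompute_labels.py",
                     "--dataset", dataset] + extra)
    # Step 2: train TURTLE on all requested representation spaces together.
    commands.append(["python", "run_turtle.py",
                     "--dataset", dataset, "--phis", *phis] + extra)
    return commands
```

To run the commands, pass each list to `subprocess.run(cmd, check=True)` from the repository root, for example for `quick_start_commands("cifar100", ["clipvitL14", "dinov2"])`.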
Data Preparation
Most datasets are downloaded automatically when running precompute_representations.py and precompute_labels.py. However, some datasets must be downloaded manually. Please check dataset_preparation/data_utils.py for a guide to preparing all the datasets used in our paper.
As an example, to prepare the pets dataset, which is not directly available in torchvision.datasets, one can run:
python dataset_preparation/prepare_pets.py -i ./data/datasets/pets -o ./data/datasets/pets -d
to download and extract the dataset at ./data/datasets/pets.
After downloading the dataset, run the following command to precompute the representations and labels:
python precompute_representations.py --dataset ${DATASET} --phis ${REPRESENTATION}
python precompute_labels.py --dataset ${DATASET}
Datasets and representations covered in this repo:
- 26 datasets:
food101, cifar10, cifar100, birdsnap, sun397, cars, aircraft, dtd, pets, caltech101, flowers, mnist, fer2013, stl10, eurosat, resisc45, gtsrb, kitti, country211, pcam, ucf101, kinetics700, clevr, hatefulmemes, sst, imagenet.
- 9 representations:
clipRN50, clipRN101, clipRN50x4, clipRN50x16, clipRN50x64, clipvitB32, clipvitB16, clipvitL14, dinov2.
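When scripting over these names, it can help to catch typos before launching a long job. Below is a minimal sketch with the names copied from the lists above. The constants and the `check_names` helper are hypothetical and not part of the repository, which may validate its arguments differently.

```python
from typing import Iterable

# Names copied from the lists above.
DATASETS = (
    "food101", "cifar10", "cifar100", "birdsnap", "sun397", "cars",
    "aircraft", "dtd", "pets", "caltech101", "flowers", "mnist", "fer2013",
    "stl10", "eurosat", "resisc45", "gtsrb", "kitti", "country211", "pcam",
    "ucf101", "kinetics700", "clevr", "hatefulmemes", "sst", "imagenet",
)
REPRESENTATIONS = (
    "clipRN50", "clipRN101", "clipRN50x4", "clipRN50x16", "clipRN50x64",
    "clipvitB32", "clipvitB16", "clipvitL14", "dinov2",
)


def check_names(dataset: str, phis: Iterable[str]) -> None:
    """Raise ValueError if a name is not listed above (hypothetical helper)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset!r}")
    unknown = [phi for phi in phis if phi not in REPRESENTATIONS]
    if unknown:
        raise ValueError(f"unknown representations: {unknown}")
```

Calling `check_names("cifar100", ["clipvitL14", "dinov2"])` returns silently, while a misspelled name raises `ValueError` before any script is started.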
Running TURTLE
Once the representations and labels are precomputed, to train TURTLE with a single space, run:
python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION}
or to train TURTLE with multiple representation spaces, run
python run_turtle.py --dataset ${DATASET} --phis ${REPRESENTATION1} ${REPRESENTATION2}
You can also use --inner_lr, --outer_lr and --warm_start to specify the inner step size, the outer step size, and whether to use cold-start or warm-start bilevel optimization. Furthermore, use --cross_val to compute the generalization score of the found labeling after training. You can perform a hyperparameter sweep and use the generalization score to select the best hyperparameters without using ground truth labels.
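A hyperparameter sweep of this kind could be prepared as in the sketch below. It is illustrative only and rests on several assumptions: `--inner_lr` and `--outer_lr` are assumed to take one numeric value each, `--cross_val` is assumed to be a switch without a value, and the helper name `sweep_commands` and the grid values in the usage note are placeholders, not recommended settings. The sketch stops at building the commands because how the generalization score is reported is not described here.

```python
import itertools
from typing import Iterable, List


def sweep_commands(dataset: str, phis: List[str],
                   inner_lrs: Iterable[float],
                   outer_lrs: Iterable[float]) -> List[List[str]]:
    """One run_turtle.py command per (inner_lr, outer_lr) pair (hypothetical)."""
    commands = []
    for inner_lr, outer_lr in itertools.product(inner_lrs, outer_lrs):
        commands.append([
            "python", "run_turtle.py",
            "--dataset", dataset,
            "--phis", *phis,
            # Assumption: both step-size flags take a single numeric value.
            "--inner_lr", str(inner_lr),
            "--outer_lr", str(outer_lr),
            # Assumption: --cross_val is a switch that takes no value.
            "--cross_val",
        ])
    return commands
```

For example, `sweep_commands("cifar100", ["clipvitL14", "dinov2"], [0.001, 0.0001], [0.001, 0.0001])` yields four commands, one per grid point. After running them, compare the generalization scores of the runs and keep the setting with the best score.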
Pre-trained Checkpoints
We also release the labelings found by TURTLE for all datasets and all model architectures used in our paper. To download pre-trained checkpoints, run:
wget https://brbiclab.epfl.ch/wp-content/uploads/2024/06/turtle_tasks.zip
unzip turtle_tasks.zip
Then, you can evaluate a pre-trained TURTLE checkpoint with a single representation space by running:
python evaluate.py --dataset cifar100 --phis clipvitL14 --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/clipvitL14/cifar100.pt
python evaluate.py --dataset cifar100 --phis dinov2 --task_ckpt {PATH_TO_TURTLE_TASKS}/1space/dinov2/cifar100.pt
or with two representation spaces by running:
python evaluate.py --dataset cifar100 --phis clipvitL14 dinov2 --task_ckpt {PATH_TO_TURTLE_TASKS}/2space/clipvitL14_dinov2/cifar100.pt
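The three checkpoint paths above share a pattern that can be generated programmatically. The sketch below assumes that the rest of the archive follows the same layout, and that the names in a two-space directory are joined with an underscore in the order shown above. The helper name `task_ckpt_path` is hypothetical; only the three paths shown are confirmed by this README.

```python
from pathlib import PurePosixPath
from typing import List


def task_ckpt_path(tasks_root: str, dataset: str, phis: List[str]) -> str:
    """Build a checkpoint path following the pattern of the examples above."""
    if not phis:
        raise ValueError("at least one representation is required")
    # Assumption: the directory is named after the number of spaces used,
    # e.g. "1space" or "2space".
    space_dir = f"{len(phis)}space"
    # Assumption: representation names are joined with "_" in the given order.
    phi_dir = "_".join(phis)
    return str(PurePosixPath(tasks_root) / space_dir / phi_dir / f"{dataset}.pt")
```

The result can be passed as the value of `--task_ckpt`, with `tasks_root` set to the directory where turtle_tasks.zip was extracted.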
Baselines
We also provide implementations of the Zero-shot Transfer with CLIP, Linear Probe and K-Means baselines in the baselines folder. The Linear Probe and K-Means baselines use cuml for its highly efficient CUDA implementations.
Linear Probe
Precompute the representations and then perform linear probe evaluation by running:
python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION}
To select the L2 regularization strength for better performance, run:
python baselines/linear_probe.py --dataset ${DATASET} --phis ${REPRESENTATION} --validation
K-Means
Precompute the representations and run K-Means baseline:
python baselines/kmeans.py --dataset ${DATASET} --phis ${REPRESENTATION}
Zero-shot Transfer
Run CLIP zero-shot transfer:
python baselines/clip_zs.py --dataset ${DATASET} --phis ${REPRESENTATION}
Acknowledgements
While developing TURTLE we greatly benefited from the open-source repositories:
Citing
If you find our code useful, please consider citing:
@inproceedings{
gadetsky2024let,
title={Let Go of Your Labels with Unsupervised Transfer},
author={Gadetsky, Artyom and Jiang, Yulun and Brbi\'c, Maria},
booktitle={International Conference on Machine Learning},
year={2024},
}