Distilling Datasets Into Less Than One Image
Official PyTorch Implementation for the "Distilling Datasets Into Less Than One Image" paper.
Poster Dataset Distillation (PoDD): We propose PoDD, a new dataset distillation setting for a tiny, under 1 image-per-class (IPC) budget. In this example, the standard method attains an accuracy of 35.5% on CIFAR-100 with approximately 100k pixels, while PoDD achieves an accuracy of 35.7% with less than half the pixels (roughly 40k).
Distilling Datasets Into Less Than One Image
Asaf Shul*, Eliahu Horwitz*, Yedid Hoshen
*Equal contribution
https://arxiv.org/abs/2403.12040

Abstract: Dataset distillation aims to compress a dataset into a much smaller one so that a model trained on the distilled dataset achieves high accuracy. Current methods frame this as maximizing the distilled classification accuracy for a budget of K distilled images-per-class, where K is a positive integer. In this paper, we push the boundaries of dataset distillation, compressing the dataset into less than an image-per-class. It is important to realize that the meaningful quantity is not the number of distilled images-per-class but the number of distilled pixels-per-dataset. We therefore propose Poster Dataset Distillation (PoDD), a new approach that distills the entire original dataset into a single poster. The poster approach motivates new technical solutions for creating training images and learnable labels. Our method can achieve comparable or better performance with less than an image-per-class compared to existing methods that use one image-per-class. Specifically, our method establishes a new state-of-the-art performance on CIFAR-10, CIFAR-100, and CUB200 using as little as 0.3 images-per-class.
Poster distillation progress over time followed by a semantic visualization of the distilled classes using a poster of CIFAR-10 with 1 IPC
Project Structure
This project consists of:
- `main.py` - Main entry point (handles user run arguments).
- `src/base.py` - Main worker for the distillation process.
- `src/PoDD.py` - PoDD implementation using RaT-BPTT as the underlying dataset distillation algorithm.
- `src/PoCO.py` - PoCO class ordering strategy implementation, using CLIP text embeddings.
- `src/PoDDL.py` - PoDDL soft labeling strategy implementation.
- `src/PoDD_utils.py` - Utility functions for PoDD.
- `src/data_utils.py` - Utility functions for data handling.
- `src/util.py` - General utility functions.
- `src/convnet.py` - ConvNet model for the distillation process.
Installation
- Clone the repo:
```bash
git clone https://github.com/AsafShul/PoDD
cd PoDD
```
- Create a new environment with the needed libraries from the `environment.yml` file, then activate it:
```bash
conda env create -f environment.yml
conda activate podd
```
Dataset Preparation
This implementation supports the following 4 datasets:
CIFAR-10 and CIFAR-100
Both the CIFAR-10 and CIFAR-100 datasets are built-in and will be downloaded automatically.
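A minimal sketch of such an automatic download via torchvision (the `root` path here is illustrative; the repo's actual data handling lives in `src/data_utils.py`):

```python
# Minimal sketch of an automatic CIFAR download via torchvision.
# The root path is illustrative; see src/data_utils.py for the repo's actual handling.
from torchvision import datasets

cifar10 = datasets.CIFAR10(root='./datasets', train=True, download=True)
cifar100 = datasets.CIFAR100(root='./datasets', train=True, download=True)
```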
CUB200
- Download the data from here
- Extract the dataset into `./datasets/CUB200`
Tiny ImageNet
- Download the dataset by running `wget http://cs231n.stanford.edu/tiny-imagenet-200.zip`
- Extract the dataset into `./tiny-imagenet-200/tiny-imagenet-200`
- Preprocess the validation split of the dataset to fit torchvision's `ImageFolder` structure. This can be done by running the function `format_tiny_imagenet_val` located in `./src/data_utils.py` (see the sketch below).
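The following is an illustrative sketch of that restructuring, assuming the standard Tiny ImageNet layout (`val/images/*.JPEG` plus `val/val_annotations.txt`); the repo's `format_tiny_imagenet_val` is the authoritative implementation and may differ in its details:

```python
# Illustrative only: reorganize the Tiny ImageNet validation split into an
# ImageFolder-style layout (one wnid subdirectory per class, mirroring the
# train split). The repo's format_tiny_imagenet_val may differ in details.
import os
import shutil

val_dir = './tiny-imagenet-200/tiny-imagenet-200/val'

# val_annotations.txt maps each validation image filename to its wnid (class id).
with open(os.path.join(val_dir, 'val_annotations.txt')) as f:
    for line in f:
        filename, wnid = line.strip().split('\t')[:2]
        class_dir = os.path.join(val_dir, wnid, 'images')
        os.makedirs(class_dir, exist_ok=True)
        shutil.move(os.path.join(val_dir, 'images', filename),
                    os.path.join(class_dir, filename))
```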
Running PoDD
The `main.py` script is the entry point for running PoDD.
Below are examples of running PoDD on the CIFAR-10, CIFAR-100, CUB200, and Tiny ImageNet datasets with a 0.9 IPC budget.
CIFAR-10
```bash
python main.py --name=PoDD-CIFAR10-LT1-90 --distill_batch_size=96 --patch_num_x=16 --patch_num_y=6 --dataset=cifar10 --num_train_eval=8 --update_steps=1 --batch_size=5000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=10 --print_freq=10 --arch=convnet --window=60 --minwindow=0 --totwindow=200 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=153 --poster_height=60 --poster_class_num_x=5 --poster_class_num_y=2
```
CIFAR-100
```bash
python main.py --name=PoDD-CIFAR100-LT1-90 --distill_batch_size=50 --patch_num_x=20 --patch_num_y=20 --dataset=cifar100 --num_train_eval=8 --update_steps=1 --batch_size=2000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=10 --print_freq=10 --arch=convnet --window=100 --minwindow=0 --totwindow=300 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=303 --poster_height=303 --poster_class_num_x=10 --poster_class_num_y=10 --train_y
```
CUB200
```bash
python main.py --name=PoDD-CUB200-LT1-90 --distill_batch_size=200 --patch_num_x=60 --patch_num_y=30 --dataset=cub-200 --num_train_eval=8 --update_steps=1 --batch_size=3000 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=25 --print_freq=10 --arch=convnet --window=60 --minwindow=0 --totwindow=200 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.001 --lr=0.001 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=1 --zca --comp_ipc=1 --class_area_width=32 --class_area_height=32 --poster_width=610 --poster_height=302 --poster_class_num_x=20 --poster_class_num_y=10 --train_y
```
Tiny ImageNet
```bash
python main.py --name=PoDD_TinyImageNet-LT1-90 --distill_batch_size=30 --patch_num_x=40 --patch_num_y=20 --dataset=tiny-imagenet-200 --num_train_eval=8 --update_steps=1 --batch_size=500 --ddtype=curriculum --cctype=2 --epoch=10000 --test_freq=5 --print_freq=1 --arch=convnet --window=100 --minwindow=0 --totwindow=300 --inner_optim=Adam --outer_optim=Adam --inner_lr=0.0005 --lr=0.0005 --syn_strategy=flip_rotate --real_strategy=flip_rotate --seed=0 --zca --comp_ipc=1 --class_area_width=64 --class_area_height=64 --poster_width=1211 --poster_height=608 --poster_class_num_x=20 --poster_class_num_y=10 --train_y
```
Important Hyper-parameters
- `--patch_num_x` and `--patch_num_y` - The number of extracted overlapping patches along the x and y axes of the poster.
- `--poster_width` and `--poster_height` - The width and height of the poster (controls the distillation data budget).
- `--poster_class_num_x` and `--poster_class_num_y` - The class layout dimensions within the poster as a 2D grid (e.g., 10x10 or 20x5); their product must equal the number of classes.
- `--train_y` - If set, the model will also optimize a set of learnable labels for the poster.
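For intuition, the effective IPC of a given poster size is simply its pixel count divided by the pixel count of one full-size image per class; the arithmetic below (not part of the repo) reproduces the ~0.9 IPC budget of the commands above.

```python
# Effective images-per-class of a poster: poster pixels / (classes * per-image pixels).
def effective_ipc(poster_w, poster_h, num_classes, img_w=32, img_h=32):
    return (poster_w * poster_h) / (num_classes * img_w * img_h)

print(effective_ipc(153, 60, 10))             # CIFAR-10 command above      -> ~0.90
print(effective_ipc(303, 303, 100))           # CIFAR-100 command above     -> ~0.90
print(effective_ipc(1211, 608, 200, 64, 64))  # Tiny ImageNet command above -> ~0.90
```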
> [!TIP]
> Increase the `distill_batch_size` and `batch_size` as your GPU memory limitations allow.
Using PoDD with other Dataset Distillation Algorithms
Although we use RaT-BPTT as the underlying distillation algorithm, using PoDD with other dataset distillation algorithms should be straightforward.
The main change is replacing the distillation functionality in `src/base.py` and `src/PoDD.py` with the desired distillation algorithm.
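Conceptually, PoDD only changes where the synthetic images come from: overlapping patches are cropped from a single learnable poster and then fed to whichever distillation objective you prefer. Below is a hypothetical sketch of that interface (names, shapes, and the patch grid are illustrative, not the repo's actual API):

```python
import torch

# Hypothetical sketch (not the repo's API): a learnable poster whose evenly
# spaced, overlapping patches act as the "distilled images" consumed by any
# dataset distillation loss.
class Poster(torch.nn.Module):
    def __init__(self, height, width, patch_size, patch_num_x, patch_num_y):
        super().__init__()
        self.poster = torch.nn.Parameter(0.1 * torch.randn(3, height, width))
        self.patch_size = patch_size
        # Top-left corners of the overlapping patch grid.
        self.xs = torch.linspace(0, width - patch_size, patch_num_x).long().tolist()
        self.ys = torch.linspace(0, height - patch_size, patch_num_y).long().tolist()

    def patches(self):
        p = self.patch_size
        return torch.stack([self.poster[:, y:y + p, x:x + p]
                            for y in self.ys for x in self.xs])

# Any distillation algorithm would then optimize the poster instead of a set
# of free images, roughly:
#   synthetic_images = poster.patches()   # (patch_num_x * patch_num_y, 3, p, p)
#   loss = distillation_loss(model, synthetic_images, labels)
#   loss.backward(); optimizer.step()
```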
Citation
If you find this useful for your research, please use the following citation:
```
@article{shul2024distilling,
  title={Distilling Datasets Into Less Than One Image},
  author={Shul, Asaf and Horwitz, Eliahu and Hoshen, Yedid},
  journal={arXiv preprint arXiv:2403.12040},
  year={2024}
}
```