Image GPT
PyTorch implementation of Image GPT, based on the paper *Generative Pretraining from Pixels* (Chen et al.) and the accompanying code.
Model-generated completions of half-images from the test set. The first column is the input; the last column is the original image.
iGPT-S pretrained on CIFAR10. Completions are fairly poor because the model was trained only on CIFAR10, not on all of ImageNet.
WIP
- [ ] Batched k-means on GPU for quantization of larger datasets (currently using `sklearn.cluster.MiniBatchKMeans`; see the sketch after this list).
- [ ] BERT-style pretraining (currently only generative pre-training is supported.)
- [ ] Load pretrained models from OpenAI.
- [ ] Reproduce at least iGPT-S results.
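The batched k-means item above would replace `sklearn.cluster.MiniBatchKMeans` with a loop running entirely on the GPU. A minimal sketch of the idea in PyTorch (the function name and interface are placeholders, not part of this repo):

```python
import torch

def kmeans_gpu(x, num_clusters, num_iters=50):
    """Plain batched k-means on the GPU; a sketch, not the repo's implementation.

    x: (n, d) float tensor of pixel values, already on the GPU.
    Returns a (num_clusters, d) tensor of centroids.
    """
    # initialize centroids from randomly chosen points
    idx = torch.randperm(x.size(0), device=x.device)[:num_clusters]
    centroids = x[idx].clone()
    for _ in range(num_iters):
        # assign each point to its nearest centroid
        dists = torch.cdist(x, centroids)   # (n, num_clusters)
        assign = dists.argmin(dim=1)        # (n,)
        # recompute each centroid as the mean of its assigned points
        for k in range(num_clusters):
            mask = assign == k
            if mask.any():
                centroids[k] = x[mask].mean(dim=0)
    return centroids
```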
According to their blog post, the largest model, iGPT-L (1.4 B parameters), was trained for roughly 2500 V100-days. By greatly reducing the number of attention heads, the number of layers, and the input size (which affects the model size quadratically), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA 2070 in less than 2 hours.
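For a rough sense of where those savings come from, here is a back-of-the-envelope parameter count for a GPT-style decoder (layer norms and biases ignored; the toy config values below are invented for illustration, not taken from this repo's configs):

```python
def transformer_params(num_layers, embed_dim, num_vocab, seq_len):
    """Approximate parameter count for a GPT-style decoder (biases and LayerNorms ignored)."""
    attn = 4 * embed_dim ** 2          # Q, K, V, and output projections
    mlp = 8 * embed_dim ** 2           # two linear layers with a 4x hidden width
    blocks = num_layers * (attn + mlp)
    embeddings = num_vocab * embed_dim + seq_len * embed_dim  # token + positional
    return blocks + embeddings

# a made-up extra-small config: 2 layers, 16-dim embeddings, 8 pixel clusters, 28x28 images
print(transformer_params(num_layers=2, embed_dim=16, num_vocab=8, seq_len=28 * 28))
```

The input size enters through `seq_len` (the number of pixels), so halving the image side length cuts the positional-embedding table by 4x; on top of that, attention compute grows quadratically with `seq_len`.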
- Image GPT
  - Usage
    - Pre-trained Models
    - Compute Centroids
    - Training
      - Generative Pre-training
      - Classification Fine-tuning
    - Sampling

Usage
Pre-trained Models
Some pre-trained models are located in the `models` directory. Run `./download.sh` to download the cifar10 pre-trained iGPT-S model.
Compute Centroids
Images are downloaded, and centroids are computed using k-means with `num_clusters` clusters. These centroids are used to quantize the images before they are fed into the model.
```bash
# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8
# creates data/<dataset>_centroids.npy
```
Note: use the same `num_clusters` as `num_vocab` in your model.
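Quantization itself is just a nearest-centroid lookup. A small sketch of the idea in NumPy (the helper below is illustrative, not the repo's code; the `.npy` path matches the output noted above):

```python
import numpy as np

# centroids produced by compute_centroids.py, shape (num_clusters, channels)
centroids = np.load("data/mnist_centroids.npy")

def quantize(images):
    """Map each pixel to the index of its nearest centroid.

    images: (batch, height, width, channels) float array.
    Returns an int array of shape (batch, height, width) with values in
    [0, num_clusters) -- which is why num_clusters must equal num_vocab.
    """
    # squared distance from every pixel to every centroid
    dists = ((images[..., None, :] - centroids) ** 2).sum(axis=-1)
    return dists.argmin(axis=-1)
```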
Training
Models can be trained using `src/run.py` with the `train` subcommand.
Generative Pre-training
Models can be pretrained by specifying a dataset and model config. `configs/s_gen.yml` corresponds to iGPT-S from the paper; `configs/xxs_gen.yml` is an extra-small model for trying on toy datasets with limited compute.
```bash
python src/run.py --dataset mnist train configs/xxs_gen.yml
```
Classification Fine-tuning
Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to `--pretrained`, along with the config file and dataset.
```bash
python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt
```
Sampling
Figures like those seen above can be created using random images from the test set:
```bash
# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
```
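Under the hood, these completion figures follow the standard autoregressive recipe: feed the top half of a quantized, flattened image as a prefix and sample the remaining pixels one at a time. A sketch of that loop, assuming a model that maps a token sequence to per-position logits (the interface is an assumption, not this repo's exact API):

```python
import torch

@torch.no_grad()
def complete(model, prefix, total_len, temperature=1.0):
    """Autoregressively sample the pixels that follow `prefix`.

    prefix: (batch, prefix_len) long tensor of quantized pixel tokens --
            the top half of each image, flattened in raster order.
    Assumes model(seq) returns logits of shape (batch, len, num_vocab).
    """
    seq = prefix
    for _ in range(total_len - prefix.size(1)):
        logits = model(seq)[:, -1] / temperature       # logits for the next pixel
        probs = torch.softmax(logits, dim=-1)
        nxt = torch.multinomial(probs, num_samples=1)  # sample one token per image
        seq = torch.cat([seq, nxt], dim=1)
    return seq  # (batch, total_len); reshape to (height, width) to view
```

For MNIST, `total_len` would be 28 * 28 = 784, with a 392-token prefix for half-image completions.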
GIFs like the one seen in my tweet can be made like so:
```bash
# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt
```