Selfie: Self-supervised Pretraining for Image Embedding
Metadata
- Authors: Trieu H. Trinh, Minh-Thang Luong, Quoc V. Le
- Organization: Google Brain
- Paper: https://arxiv.org/abs/1906.02940
TL;DR
- This paper aims to translate the success of language-model pre-training from text to images by proposing a BERT-like self-supervised learning method called Selfie.
- Selfie's novelty lies in combining BERT-style masked prediction with a CPC-style (contrastive predictive coding) loss.
- The related work is well covered.
Method

- Two stages: pre-training (the focus of this work) and fine-tuning.
- Pre-train the first 3 blocks of ResNet-50, and fine-tune the whole ResNet-50.
- High-level idea:
- (1) Given an image, split it into patches, e.g., a 3x3 grid.
- (2) Randomly mask out 3 patches, for example the 3rd, 4th, and 8th patches.
- (3) Task: given the context patches (1st, 2nd, 5th, 6th, 7th, 9th), predict which patch was masked out at a given position (one position at a time).
- (4) The prediction is formulated as a classification task instead of regression (i.e., generating a patch), since regression is sensitive to small changes in the image. (See the first sketch after this list.)
- Model detail:
- Patch processing network (Pnet): the first 3 blocks of ResNet plus average pooling, which encodes the 1st-9th patches independently; this gives 9 feature vectors.
- Encoder: encode the context feature vectors (1st, 2nd, 5th, 6th, 7th, 9th) into a "context vector" (u) with an "attention pooling layer".
- Choosing which masked-out location to predict: add the "position embedding" of a masked-out patch (e.g., the 4th) to the context vector (u) to get a "query vector" (v).
- (Pointer-based) decoder: point to the masked-out patch (e.g., the 4th) based on the query vector (v), i.e., compute the dot-product similarity between each pair <i-th feature vector, v> (i = 3, 4, 8) and maximize the similarity of the <4th feature vector, v> pair with a cross-entropy loss. (See the second sketch after this list.)
- Attention Pooling Layer:
- Can be thought of as a generalized average/max pooling layer.
- Here they use Transformer layers for pooling.
- The attention blocks follow the self-attention in BERT.
- Positional embedding: decomposed into row and column embeddings that are summed together.
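To make steps (1)-(2) concrete, here is a minimal PyTorch-style sketch of splitting an image into a grid of patches and randomly masking some out. The helper name, tensor shapes, and the 3x3 default are my own illustrative assumptions, not code from the paper.

```python
import torch


def split_and_mask(image, grid=3, num_masked=3):
    """Split a square image into a `grid` x `grid` set of patches and pick
    random patches to mask out (hypothetical helper, not the paper's code).

    image: (C, H, W) tensor with H and W divisible by `grid`
    Returns the patches as (grid*grid, C, ph, pw) plus the indices of the
    masked-out and context patches.
    """
    C, H, W = image.shape
    ph, pw = H // grid, W // grid
    # (C, grid, ph, grid, pw) -> (grid, grid, C, ph, pw) -> (grid*grid, C, ph, pw)
    patches = (image.reshape(C, grid, ph, grid, pw)
                    .permute(1, 3, 0, 2, 4)
                    .reshape(grid * grid, C, ph, pw))
    perm = torch.randperm(grid * grid)
    masked_idx, context_idx = perm[:num_masked], perm[num_masked:]
    return patches, masked_idx, context_idx


# Example: a 3x3 grid gives 9 patches, 3 of which are masked out.
patches, masked_idx, context_idx = split_and_mask(torch.randn(3, 96, 96))
```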
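And here is a rough sketch of the model details above: attention pooling of the context features into u, the row+column position embedding that turns u into the query v, and the pointer-style cross-entropy over dot-product similarities with the masked patches' features. The class/function names, shapes, and head/layer counts are assumptions for illustration, not the paper's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPooling(nn.Module):
    """Transformer-style pooling over a set of patch features: a learned
    pooling token attends over the context patches and its output is used
    as the context vector u. (Illustrative stand-in, not the paper's code.)"""

    def __init__(self, dim, heads=4, layers=2):
        super().__init__()
        # dim must be divisible by heads
        self.pool_token = nn.Parameter(torch.randn(1, 1, dim))
        enc_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                               batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=layers)

    def forward(self, x):                     # x: (B, N, D) context features
        tok = self.pool_token.expand(x.size(0), -1, -1)
        out = self.encoder(torch.cat([tok, x], dim=1))
        return out[:, 0]                      # (B, D) context vector u


def pointer_decoder_loss(context_feats, masked_feats, row_emb, col_emb,
                         target_row, target_col, pooler):
    """Predict which of the masked-out patches belongs at the queried grid
    location (hypothetical shapes/names).

    context_feats: (B, N, D) Pnet features of the unmasked (context) patches
    masked_feats:  (B, M, D) Pnet features of the masked-out patches; the
                   patch at the queried location is assumed to be index 0
    row_emb, col_emb: (rows, D) and (cols, D) learned position embeddings
    target_row, target_col: (B,) grid coordinates of the queried location
    """
    u = pooler(context_feats)                            # context vector u
    pos = row_emb[target_row] + col_emb[target_col]      # row + column embedding
    v = u + pos                                          # query vector v
    # Dot-product similarity between v and each masked patch feature: (B, M)
    logits = torch.einsum('bd,bmd->bm', v, masked_feats)
    # Cross-entropy pushes up the similarity of the correct (index-0) patch.
    target = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, target)
```

In the 3x3 example from these notes, N = 6 context patches, M = 3 masked patches, and the row/column embedding tables would each have 3 entries.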
Experiment Setting
- Dataset: CIFAR-10, ImageNet 32x32 and 224x224.
- Patch size: 8x8 for the 32x32 images and 32x32 for the 224x224 images (see the patch-count snippet after this list).
- Full data (50K for CIFAR-10 & 1.2M for ImageNet) is used for the pre-training stage.
- Data is split into 5%, 10%, 20%, and 100% labeled subsets for the fine-tuning stage.
- For CIFAR-10, the 10% subset is replaced with an 8% subset (4,000 examples), following the AutoAugment and Realistic Evaluation of SSL papers.
- They also describe model training and hyper-parameters in detail.
- Train the model for 120K steps.
- The baseline ResNet achieves strong accuracy: 95.5% on CIFAR-10 and 76.9% on ImageNet 224x224.
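A quick back-of-the-envelope check of what those patch sizes imply for the number of patches per image (my own arithmetic, not stated explicitly in the notes above):

```python
# Patches per image implied by the settings above (image_size // patch_size per side).
for image_size, patch_size in [(32, 8), (224, 32)]:
    side = image_size // patch_size
    print(f"{image_size}x{image_size} image, {patch_size}x{patch_size} patches"
          f" -> {side}x{side} = {side * side} patches")
# 32x32 image, 8x8 patches -> 4x4 = 16 patches
# 224x224 image, 32x32 patches -> 7x7 = 49 patches
```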
Key Experiment Results and Findings
- Selfie pre-training acts as a regularizer when only 10% of the labeled data is used.
- Selfie pre-training improves performance across all labeled-data fractions (5%-100%).
- Pre-training helps more when there is less labeled data.
- Self-attention as the last layer helps fine-tuning performance.
Fine-tuning Sensitivity and Mismatch to Pre-training
- There are difficulties in transferring the pre-trained model across tasks, e.g., from ImageNet to CIFAR-10.
- For the 100% subset of ImageNet 224x224, additional tuning of the pre-training phase using a development set is needed to achieve the reported result.
- Input mismatch exists between pre-training (the Pnet only sees the patches) and fine-tuning (the Pnet sees the whole image).
Personal Thoughts
- I wonder whether the patch size is too small for the ResNet-based Pnet to encode, especially the 8x8 patches.
- The attention pooling layer and positional embedding are not clearly explained here; I may need to refer to the BERT paper.
- Where is the CPC loss? The paper only describes a cross-entropy loss. I may also need to refer to the CPC paper.
- The hyper-parameter selection in the Reporting Results section is not clear to me. Is the tuning done in the pre-training stage or the fine-tuning stage?
- The performance gain from SSL becomes very marginal as more labeled data becomes available, especially when the task is easy (CIFAR-10's gain is smaller than ImageNet's).
- The challenge of making SSL more useful (i.e., achieving larger gains with fully labeled data) still exists.
- Although they didn't compare against other SSL approaches (e.g., for CPC, they note that CPC uses ResNet-171), they emphasize that their baseline ResNet is stronger than previous ones, which they hope makes the experimental results more convincing to the reader.
- They did not show whether the regularization effect still holds with fully labeled data.
- A semi-supervised learning approach called UDA achieves 94.7% accuracy on the 8% CIFAR-10 subset, while this paper only achieves 80.3%, suggesting that semi-supervised learning may be more promising than self-supervised learning (?). Or maybe they can be combined, as shown in the S^4L paper.
- Other authors worth noting: Avital Oliver (author of MixMatch, S^4L, ...) and Xiaohua Zhai (author of Revisiting SSL and S^4L). Let's read their papers!