Struggling to replicate evaluation results
Hi folks,
I'm trying to replicate your linear probe evaluation results. I can only get your pre-trained model to score 77% (with the last layer) or 80.8% (with the last 4 layers) on a linear probe of CIFAR-100. I'm using the ViT-H with a patch size of 14 that was trained for 300 epochs on ImageNet-1k.
I am using the following transforms, taken from the VISSL repo. Training:
# Taken from the VISSL repo.
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224, interpolation=3),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
And testing:
test_transforms = transforms.Compose([
    transforms.Resize(size=256, interpolation=3),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
I'm training a linear model using SGD with lr=0.01, momentum=0.9, weight_decay=5e-4, nesterov=True, and I'm decaying the learning rate by a factor of 10 at epochs 8, 16, and 24.
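In code, the optimizer and schedule look roughly like this (the probe input size shown is for the last-layer case):

import torch

# Linear probe on 1280-dim last-layer features; CIFAR-100 has 100 classes.
probe = torch.nn.Linear(1280, 100)
optimizer = torch.optim.SGD(probe.parameters(), lr=0.01, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)
# Decay the learning rate by 10x at epochs 8, 16, and 24.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 16, 24], gamma=0.1)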
To get the "last four" embedding, I'm concatenating the outputs of the last four blocks to create a (B, 1024, 1280) tensor. For the averaging, I'm averaging over the spatial dimension, i.e. turning the (B, S, 1280) tensor into a (B, 1280) tensor.
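In code, that pipeline is roughly:

import torch

# Placeholder outputs from the last four blocks, each (B, 256, 1280).
feats = [torch.randn(2, 256, 1280) for _ in range(4)]

# "Last four": concatenate along the token dimension -> (B, 1024, 1280).
last_four = torch.cat(feats, dim=1)

# Averaging: mean over the spatial/token dimension -> (B, 1280).
pooled = last_four.mean(dim=1)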
Would you be able to shed any light on what I could be missing in my reproduction?
Have you tried averaging the last 4 layers along the spatial dimension and then concatenating them along the channel dimension? The wording in the paper is ambiguous, but I read it as the concatenation taking place after averaging the last 4 layers, and since it doesn't mention another averaging, it makes the most sense to me that they concatenated along the channel dimension here.
So, you think I need to merge these to create a (B, 256, 5120) tensor, and then average over the spatial dim to create a (B, 5120) tensor?
That's one way. I was thinking of averaging them before the concatenation, so it would be a [B, 1, 5120] (i.e. four [B, 1, 1280] tensors concatenated). I think that's a possibility given how it's worded. I personally would have done it the same way you did, but it can be interpreted that way.
Another possibility: data2vec 2.0 does a similar last-layer averaging, but they layer-norm each of the layers before concatenating and then again after averaging. So possibly the paper forgot to mention some layer normalizations before the concat.
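Roughly, in shapes (the extra layer norms are only one possible reading, not something the paper states):

import torch
import torch.nn.functional as F

feats = [torch.randn(2, 256, 1280) for _ in range(4)]  # placeholder block outputs

# Average each layer over tokens first, then concatenate along the channel dim -> (B, 5120).
avg_then_concat = torch.cat([f.mean(dim=1) for f in feats], dim=-1)

# One possible reading of the data2vec 2.0-style variant: layer-norm each block
# output, average over tokens, concatenate, then normalize once more.
normed = [F.layer_norm(f, (f.size(-1),)).mean(dim=1) for f in feats]
d2v_style = F.layer_norm(torch.cat(normed, dim=-1), (4 * 1280,))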
What's the practical difference between averaging them before concatenating vs averaging them after?
Well, the main difference is the dimension along which the concatenation happens. Averaging before the concatenation (assuming no other operations are applied afterwards) would mean they concatenated along the channel dimension, which also means that the linear projection layer they used for probing is larger.
Any further luck in replicating the results @finbarrtimbers ?
I am also curious about this; I have not been able to reproduce it with the same success either.
I was able to reproduce it with some degree of success. I used the concatenation mentioned above, where I concatenate the last 4 outputs along the channel dimension and average across the token dimension. Additionally, the LR needs to be 0.05. This was the largest factor, as no other LR mentioned in the paper worked for me.
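Roughly, the pipeline that worked for me (placeholder shapes for ViT-H/14; the SGD settings other than the LR just mirror the original post and are an assumption on my part):

import torch

feats = [torch.randn(2, 256, 1280) for _ in range(4)]  # placeholder block outputs

# Concatenate the last 4 outputs along the channel dimension -> (B, 256, 5120),
# then average across the token dimension -> (B, 5120).
probe_input = torch.cat(feats, dim=-1).mean(dim=1)

probe = torch.nn.Linear(4 * 1280, 100)  # CIFAR-100 head, as in the original post
# lr=0.05 was the decisive change; momentum/weight decay/nesterov are assumed
# from the original post, not something I tuned separately.
optimizer = torch.optim.SGD(probe.parameters(), lr=0.05, momentum=0.9,
                            weight_decay=5e-4, nesterov=True)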
No luck. I gave up.
Do you have any scripts for evaluating this model? I would love not to waste time figuring out something that others have already figured out.
This script can be used if you just swap the models. It gives very similar results to the paper for the pretrained ViT-H/14: https://github.com/facebookresearch/mae/blob/main/main_linprobe.py
@vedal Wondering if you could share your code for this? Would really appreciate it.
I'm not getting great results using this script. The accuracies seem low. Not sure if it's a hyperparameter problem or something else.
@ryang555, @finbarrtimbers this is how I'm doing it. I notice that the training loss is decreasing very slowly compared to MAE pretraining, which I find confusing and makes me skeptical, but I will wait for more epochs to see how it behaves.
Start lr: 5.5e-04 Lr: 9.0e-04 Final lr: 5.0e-04
import torch
import torch.nn as nn
import torch.nn.functional as F


class FinetuningModel(nn.Module):
    '''
    "Because our I-JEPA implementation uses Vision Transformer architectures without a [cls] token,
    we adapt the default VISSL evaluation recipe to utilize the average-pooled patch representation
    instead of the [cls] token.
    We therefore report the best linear evaluation number among the following representations:
    1) the average-pooled patch representation of the last layer,
    2) the concatenation of the last 4 layers of the average-pooled patch representations."
    '''
    def __init__(self, pretrained_model, drop_path, nb_classes):
        super(FinetuningModel, self).__init__()
        self.pretrained_model = pretrained_model
        self.drop_path = drop_path
        self.nb_classes = nb_classes
        self.pretrained_model.drop_path = 0.2
        self.pretrained_model.drop_rate = 0.25
        self.n_intermediate_outputs = 4
        # Average over all patch tokens (the kernel spans the full token dimension).
        self.average_pool = nn.AvgPool1d(self.pretrained_model.patch_embed.num_patches, stride=1)
        self.head_drop = nn.Dropout(drop_path)
        self.mlp_head = nn.Linear(self.n_intermediate_outputs * self.pretrained_model.embed_dim,
                                  self.nb_classes)

    def get_n_intermediate_outputs(self, n, x):
        # -- patchify x
        x = self.pretrained_model.patch_embed(x)  # -> (B, 256, 1280)
        # -- add positional embedding to x
        pos_embed = self.pretrained_model.interpolate_pos_encoding(x, self.pretrained_model.pos_embed)  # see vision_transformer.py @ line 410 for more info
        x = x + pos_embed
        # Extract the (B, 256, 1280) representation from each of the last 4 blocks,
        # average each one individually -> 4x (B, 1, 1280), squeeze, then
        # concatenate into a single (B, 5120) representation.
        outputs = []
        n_blocks = len(self.pretrained_model.blocks) - 1
        layers = [(n_blocks - i) for i in reversed(range(n))]  # -> [28, 29, 30, 31]
        # -- 1. forward propagation through all blocks
        for b, blk in enumerate(self.pretrained_model.blocks):
            x = blk(x)
            # -- 2. patch-wise averaging and normalization
            if b in layers:
                h = self.average_pool(x.transpose(1, 2)).transpose(1, 2)  # (B, 1, 1280)
                h = h.squeeze(1)                                          # (B, 1280)
                h = F.layer_norm(h, (h.size(-1),))  # normalize over the feature dim
                outputs.append(h)
        # -- 3. concatenation
        output = torch.cat(outputs, dim=-1)
        return output

    def forward(self, x):
        x = self.get_n_intermediate_outputs(self.n_intermediate_outputs, x)
        x = self.head_drop(x)  # as done in timm.models
        x = self.mlp_head(x)
        return x
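For a quick shape check (here `encoder` is a placeholder for the pretrained I-JEPA ViT-H/14 target encoder loaded elsewhere, and drop_path=0.1 is an arbitrary value):

# `encoder` stands for the loaded pretrained ViT-H/14 (loading code omitted).
model = FinetuningModel(encoder, drop_path=0.1, nb_classes=100)
with torch.no_grad():
    logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # expected torch.Size([2, 100]); probe input is 4 * 1280 = 5120-dim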
Starting train_loss: 2.823
INFO:root:Epoch 1
INFO:root:avg. train_loss 2.783
INFO:root:avg. test_loss 12.118 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.001
INFO:root:Epoch 2
INFO:root:avg. train_loss 2.710
INFO:root:avg. test_loss 12.414 avg. Accuracy@1 0.004 - avg. Accuracy@5 0.010
INFO:root:Epoch 3
INFO:root:avg. train_loss 2.655
INFO:root:avg. test_loss 12.638 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.009
INFO:root:Epoch 4
INFO:root:avg. train_loss 2.599
INFO:root:avg. test_loss 13.040 avg. Accuracy@1 0.003 - avg. Accuracy@5 0.007
INFO:root:Epoch 5
INFO:root:avg. train_loss 2.539
INFO:root:avg. test_loss 13.141 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.004
INFO:root:Epoch 6
INFO:root:avg. train_loss 2.475
INFO:root:avg. test_loss 13.677 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.009
INFO:root:Epoch 7
Loss: 2.4096567630767822
INFO:root:avg. test_loss 13.900 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.007
INFO:root:Epoch 8
Loss: 2.1995465755462646
INFO:root:avg. test_loss 13.965 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.007
INFO:root:Epoch 9
Loss: 2.3384854793548584
INFO:root:avg. test_loss 14.337 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.009
INFO:root:Epoch 10
Loss: 1.8019111156463623
INFO:root:avg. test_loss 14.660 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.005
INFO:root:Epoch 12
Loss: 2.0356271266937256
INFO:root:avg. test_loss 14.674 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.004
INFO:root:Epoch 13
Loss: 1.7954494953155518
INFO:root:avg. test_loss 14.860 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.002
INFO:root:Epoch 14
Loss: 2.2272109985351562
INFO:root:avg. test_loss 14.691 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.006
INFO:root:Epoch 15
Loss: 1.9833906888961792
INFO:root:avg. test_loss 14.710 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.005
INFO:root:Epoch 16
Loss: 2.02565336227417
INFO:root:avg. test_loss 14.748 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.007
INFO:root:Epoch 17
Loss: 2.025174140930176
INFO:root:avg. test_loss 15.031 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.008
INFO:root:Epoch 18
Loss: 2.100982666015625
INFO:root:avg. test_loss 14.659 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.005
INFO:root:Epoch 19
Loss: 1.631501317024231
INFO:root:avg. test_loss 14.682 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.005
INFO:root:Epoch 20
Loss: 1.408584475517273
INFO:root:avg. test_loss 14.592 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.003
INFO:root:Epoch 21
Loss: 1.44764244556427
INFO:root:avg. test_loss 14.704 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.010
INFO:root:Epoch 22
Loss: 1.7822728157043457
INFO:root:avg. test_loss 14.704 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.006