
Struggling to replicate evaluation results

Open finbarrtimbers opened this issue 1 year ago • 14 comments

Hi folks,

I'm trying to replicate your linear probe evaluation results. I can only get your pre-trained model to score 77% (with the last layer) or 80.8% (with the last 4 layers) on a linear probe of CIFAR-100. I'm using the ViT-H with a patch size of 14 that was trained for 300 epochs on ImageNet-1k.
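
For context, a minimal sketch of how that checkpoint might be loaded with the repo's ViT code; the file name, the vit_huge constructor, and the target_encoder / module. checkpoint keys are assumptions about the ijepa codebase, not details taken from this thread:

import torch
from src.models.vision_transformer import vit_huge  # from the ijepa repo

encoder = vit_huge(patch_size=14)
ckpt = torch.load('IN1K-vit.h.14-300e.pth.tar', map_location='cpu')
# The released checkpoints appear to store DDP-wrapped state dicts; strip the 'module.' prefix.
state_dict = {k.replace('module.', ''): v for k, v in ckpt['target_encoder'].items()}
encoder.load_state_dict(state_dict)
encoder.eval()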

I am using the following transforms, taken from the VISSL repo. Training:

# Taken from the VISSL repo.
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224, interpolation=3),  # 3 = bicubic
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

And testing:

test_transforms = transforms.Compose([
    transforms.Resize(size=256, interpolation=3),  # 3 = bicubic
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

I'm training a linear model using SGD w/ parameters lr=0.01, momentum=0.9, weight_decay=5e-4, nesterov=True, and I'm decaying the learning rate by a factor of 10 on the 8th, 16th, and 24th epochs.
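
For reference, a minimal sketch of that optimizer and schedule; the 1280-d, 100-class linear probe is illustrative, and only the hyperparameters come from the description above:

import torch

# Hypothetical linear probe over the 1280-d pooled features, 100 classes for CIFAR-100.
linear_probe = torch.nn.Linear(1280, 100)

optimizer = torch.optim.SGD(
    linear_probe.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
    nesterov=True,
)
# Decay the learning rate by a factor of 10 at epochs 8, 16, and 24.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[8, 16, 24], gamma=0.1)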

To get the "last four" embedding, I'm concatenating the output of the last four blocks to create a (B, 1024, 1280) tensor.

For the averaging, I'm averaging over the spatial dimension, i.e. turning the (B, S, 1280) tensor into a (B, 1280) tensor.
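
A minimal sketch of that pooling, with stand-in tensors for the last four block outputs (names and the dummy batch size are illustrative):

import torch

block_outputs = [torch.randn(2, 256, 1280) for _ in range(4)]  # stand-ins for the last 4 block outputs

last_four = torch.cat(block_outputs, dim=1)  # concat along the token dim -> (B, 1024, 1280)
pooled = last_four.mean(dim=1)               # average over the spatial dim -> (B, 1280)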

Would you be able to shed any light on what I could be missing in my reproduction?

finbarrtimbers avatar Aug 23 '23 17:08 finbarrtimbers

Have you tried averaging the last 4 layers along the spatial dimension and then concatenating them along the channel dimension? The wording in the paper is ambiguous, but I read it as the concatenation taking place after averaging the last 4 layers, and since it doesn't mention another averaging, it makes the most sense to me that they concatenated along the channel dimension here.
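
A minimal sketch of that reading, again with stand-in tensors for the last four block outputs:

import torch

block_outputs = [torch.randn(2, 256, 1280) for _ in range(4)]  # stand-ins for the last 4 block outputs

pooled_per_layer = [h.mean(dim=1) for h in block_outputs]  # average each layer over tokens -> 4 x (B, 1280)
features = torch.cat(pooled_per_layer, dim=-1)             # then concat along channels -> (B, 5120)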

ryang555 avatar Sep 08 '23 19:09 ryang555

So, you think I need to merge these to create a (B, 256, 5120) tensor, and then average the spatial dim to create a (B, 5120) tensor?

finbarrtimbers avatar Sep 08 '23 21:09 finbarrtimbers

That's one way. I was thinking of averaging them before the concatenation, so it would be a [B, 1, 5120] (so 4 [B, 1, 1280] tensors concatenated)… I'm thinking it's a possibility with how it's worded. I personally would've done it the same way you did, but it can be interpreted that way.

Another possibility: data2vec 2.0 does a similar last-layer averaging, but it layer-norms each of the layers before concatenating and then again after averaging. So possibly the paper omits some layer normalizations before the concat.
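
A minimal sketch of that layer-norm variant as described here (whether I-JEPA's evaluation actually applies these layer norms is exactly what is unclear; the stand-in tensors are illustrative):

import torch
import torch.nn.functional as F

block_outputs = [torch.randn(2, 256, 1280) for _ in range(4)]  # stand-ins for the last 4 block outputs

normed = [F.layer_norm(h, (h.size(-1),)) for h in block_outputs]  # layer-norm each layer's tokens
pooled = [h.mean(dim=1) for h in normed]                          # average over tokens -> 4 x (B, 1280)
features = torch.cat(pooled, dim=-1)                              # concat along channels -> (B, 5120)
features = F.layer_norm(features, (features.size(-1),))           # layer-norm again after averaging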

ryang555 avatar Sep 08 '23 21:09 ryang555

What's the practical difference between averaging them before concatenating vs averaging them after?

finbarrtimbers avatar Sep 08 '23 21:09 finbarrtimbers

Well, the main difference is the dimension along which the concatenation happens. Averaging before the concatenation (assuming no other operations are applied afterwards) would mean they concatenated along the channel dimension, which also means that the linear projection layer they used for probing is larger.
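
Concretely, the two readings change the input width of the probe; a quick sketch, using 100 classes for the CIFAR-100 case above:

import torch.nn as nn

# Average over tokens, then concatenate 4 layers along channels -> 5120-d probe input.
probe_concat_channels = nn.Linear(4 * 1280, 100)

# Concatenate along tokens, then average -> still a 1280-d probe input.
probe_concat_tokens = nn.Linear(1280, 100)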

ryang555 avatar Sep 08 '23 21:09 ryang555

Any further luck in replicating the results @finbarrtimbers ?

vedal avatar Feb 24 '24 23:02 vedal

I am also curious about this; I have not been able to implement it with the same success either.

ChristopherMarais avatar Mar 19 '24 01:03 ChristopherMarais

I was able to implement it with some degree of success. I used the concatenation mentioned above, where I concatenate the last 4 outputs along the channel dimension and average across the token dimension. Additionally, the LR needs to be 0.05. This was the largest factor, as no other LR mentioned in the paper turned out successful for me.

ryang555 avatar Mar 19 '24 01:03 ryang555

No luck. I gave up.

finbarrtimbers avatar Mar 19 '24 01:03 finbarrtimbers

Do you have any scripts for how to evaluate this model? I would love not to waste time figuring out something that others have figured out before.

ChristopherMarais avatar Mar 19 '24 01:03 ChristopherMarais

This script can be used if you just swap the models. It gives very similar results to the paper for the pretrained ViT-H/14: https://github.com/facebookresearch/mae/blob/main/main_linprobe.py

vedal avatar Mar 19 '24 08:03 vedal

@vedal Wondering if you could share your code for this? Would really appreciate it.

BrandonMan123 avatar Mar 22 '24 21:03 BrandonMan123

This script can be used if you just swap the models. It gives very similar results to the paper for the pretrained ViT-H/14: https://github.com/facebookresearch/mae/blob/main/main_linprobe.py

I'm not getting great results using this script. The accuracies seem low. Not sure if it is a hyperparameter problem or something else.

ChristopherMarais avatar Mar 25 '24 02:03 ChristopherMarais

@ryang555, @finbarrtimbers this is how I'm doing it. I sense that the training loss is decreasing very slowly in comparison to MAE pretraining, which I find rather confusing and makes me skeptical, but I will wait for more epochs to see how it behaves.

Start lr: 5.5e-04 Lr: 9.0e-04 Final lr: 5.0e-04


import torch
import torch.nn as nn
import torch.nn.functional as F


class FinetuningModel(nn.Module):
    def __init__(self, pretrained_model, drop_path, nb_classes):
        super(FinetuningModel, self).__init__()
        self.pretrained_model = pretrained_model
        
        self.drop_path = drop_path
        self.nb_classes = nb_classes
        
        self.pretrained_model.drop_path = 0.2  
        self.pretrained_model.drop_rate = 0.25
        
        self.n_intermediate_outputs = 4
        

        self.average_pool = nn.AvgPool1d((self.pretrained_model.patch_embed.num_patches), stride=1)

        self.head_drop = nn.Dropout(drop_path)

        self.mlp_head = nn.Linear(self.n_intermediate_outputs * self.pretrained_model.embed_dim,
                                    self.nb_classes)   

    '''
        " Because our I-JEPA implementation uses Vision Transformer architectures without a [cls] token, 
        we adapt the default VISSL evaluation recipe to utilize the average-pooled patch representation
        instead of the [cls] token. 
        We therefore report the best linear evaluation number among the following representations: 
        1) the average-pooled patch representation of the last layer,
        2) the concatenation of the last 4 layers of the average-pooled patch representations."    
    '''
    def get_n_intermediate_outputs(self, n, x):

        # -- patchify x
        x = self.pretrained_model.patch_embed(x) # -> (B, 256, 1280)

        # -- add positional embedding to x 
        pos_embed = self.pretrained_model.interpolate_pos_encoding(x, self.pretrained_model.pos_embed) # See vision_transformer.py @Line 410 for more info.
        x = x + pos_embed
        
        # Extract the representation (B, 256, 1280) from the last 4 layers
        # averaging them individually -> 4x (B, 1, 1280), then
        # concatenate into a single representation (B, 5120).
        outputs = []
        n_blocks = len(self.pretrained_model.blocks) - 1            
        layers = [(n_blocks - i) for i in reversed(range(n))] # -> [28, 29, 30, 31]
        # -- 1. fwd prop
        for b, blk in enumerate(self.pretrained_model.blocks):
            x = blk(x)
            # -- 2. Patch-wise averaging and normalization.
            if b in layers:
                h = self.average_pool(x.transpose(1, 2)).transpose(1, 2)
                h = h.squeeze(1) # adjust
                h = F.layer_norm(h, (h.size(-1),)) # Normalize over feature-dim    
                outputs.append(h)
                        
        # -- 3. Concatenation
        output = torch.cat(outputs, dim=-1)
        return output

    def forward(self, x):

        x = self.get_n_intermediate_outputs(self.n_intermediate_outputs, x)
        
        x = self.head_drop(x) # As performed in timm.models
        
        x = self.mlp_head(x)
        return x

Starting train_loss: 2.823 

INFO:root:Epoch 1
INFO:root:avg. train_loss 2.783
INFO:root:avg. test_loss 12.118 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.001

INFO:root:Epoch 2
INFO:root:avg. train_loss 2.710
INFO:root:avg. test_loss 12.414 avg. Accuracy@1 0.004 - avg. Accuracy@5 0.010

INFO:root:Epoch 3
INFO:root:avg. train_loss 2.655
INFO:root:avg. test_loss 12.638 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.009

INFO:root:Epoch 4
INFO:root:avg. train_loss 2.599
INFO:root:avg. test_loss 13.040 avg. Accuracy@1 0.003 - avg. Accuracy@5 0.007

INFO:root:Epoch 5
INFO:root:avg. train_loss 2.539
INFO:root:avg. test_loss 13.141 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.004

INFO:root:Epoch 6
INFO:root:avg. train_loss 2.475
INFO:root:avg. test_loss 13.677 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.009

INFO:root:Epoch 7
Loss: 2.4096567630767822
INFO:root:avg. test_loss 13.900 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.007

INFO:root:Epoch 8
Loss: 2.1995465755462646
INFO:root:avg. test_loss 13.965 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.007

INFO:root:Epoch 9
Loss: 2.3384854793548584
INFO:root:avg. test_loss 14.337 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.009

INFO:root:Epoch 10
Loss: 1.8019111156463623
INFO:root:avg. test_loss 14.660 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.005

INFO:root:Epoch 12
Loss: 2.0356271266937256
INFO:root:avg. test_loss 14.674 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.004

INFO:root:Epoch 13
Loss: 1.7954494953155518
INFO:root:avg. test_loss 14.860 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.002

INFO:root:Epoch 14
Loss: 2.2272109985351562
INFO:root:avg. test_loss 14.691 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.006

INFO:root:Epoch 15
Loss: 1.9833906888961792
INFO:root:avg. test_loss 14.710 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.005

INFO:root:Epoch 16
Loss: 2.02565336227417
INFO:root:avg. test_loss 14.748 avg. Accuracy@1 0.002 - avg. Accuracy@5 0.007

INFO:root:Epoch 17
Loss: 2.025174140930176
INFO:root:avg. test_loss 15.031 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.008

INFO:root:Epoch 18
Loss: 2.100982666015625
INFO:root:avg. test_loss 14.659 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.005

INFO:root:Epoch 19
Loss: 1.631501317024231
INFO:root:avg. test_loss 14.682 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.005

INFO:root:Epoch 20
Loss: 1.408584475517273
INFO:root:avg. test_loss 14.592 avg. Accuracy@1 0.000 - avg. Accuracy@5 0.003

INFO:root:Epoch 21
Loss: 1.44764244556427
INFO:root:avg. test_loss 14.704 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.010

INFO:root:Epoch 22
Loss: 1.7822728157043457
INFO:root:avg. test_loss 14.704 avg. Accuracy@1 0.001 - avg. Accuracy@5 0.006

FalsoMoralista avatar Apr 21 '24 17:04 FalsoMoralista