
How to load ijepa checkpoints?

Open namrahrehman opened this issue 1 year ago • 6 comments

I am trying to use this model for CIFAR-10 classification in Google Colab. To study its layers, I cloned this repo and am loading the model as follows:

import torch

from vision_transformer import vit_huge

# Initialize the ViT-H model with the specified patch size and resolution
model = vit_huge(patch_size=14, num_classes=1000)  # Adjust num_classes if needed

# Load the state dictionary from the file
state_dict = torch.load('/content/drive/MyDrive/IN1K-vit.h.14-300e.pth.tar')

# Load the state dictionary into the model
model.load_state_dict(state_dict)

# Print the layers/modules of the model for inspection
def print_model_layers(model, prefix=""):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Module):
            module_name = prefix + "." + name if prefix else name
            print(module_name)
            print_model_layers(module, prefix=module_name)

print_model_layers(model)

but I get the following error:

RuntimeError                              Traceback (most recent call last)
in <cell line: 6>()
      4
      5 # Load the state dictionary into the model
----> 6 model.load_state_dict(state_dict)
      7
      8 # Print the layers/modules of the model

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                 self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for VisionTransformer:
    Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.2.norm2.weight", "blocks.2.norm2.bias", "blocks.2.mlp.fc1.weight", "blocks.2.mlp.fc1.bias", "blocks.2.mlp.fc2.weight", "blocks.2.mlp.fc2.bias", "blocks.3.norm1.weight", "blocks.3.norm1.bias", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.3.norm2.weight", "blocks.3.norm2.bias", "blocks.3.mlp.fc1.weight", "blocks.3.mlp.fc1.bias", "blocks.3.mlp.fc2.weight", "blocks.3.mlp.fc2.bias", "blocks.4.norm1.weight", "blocks.4.norm1.bias", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "bl...
    Unexpected key(s) in state_dict: "encoder", "predictor", "opt", "scaler", "target_encoder", "epoch", "loss", "batch_size", "world_size", "lr".

I do not understand which ViT from vision_transformer.py I am supposed to use for this checkpoint (IN1K-vit.h.14-300e.pth.tar), because using vit_huge gives the error above.

namrahrehman · Oct 10 '23, 13:10

It seems the checkpoint stores the full training state from a DDP-wrapped run (hence the "module." prefix on every weight key), while you tried to load it into a bare encoder. Something like this should work:

ckpt = torch.load(load_path, map_location=torch.device('cpu'))
pretrained_dict = ckpt['encoder']  # pretrained weights of the context encoder

# -- loading encoder: copy each weight after stripping the DDP "module." prefix
# (`encoder` here is a bare ViT from vision_transformer.py, e.g. vit_huge(patch_size=14))
for k, v in pretrained_dict.items():
    encoder.state_dict()[k[len("module."):]].copy_(v)
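
For completeness, here is a minimal end-to-end sketch of the above. It is only a sketch under a few assumptions: the cloned repo's vision_transformer.py is importable, the bare `encoder` is simply the ViT-H/14 built from that file, and the checkpoint sits in the current directory; `ckpt['target_encoder']` can be swapped in the same way if you want the target encoder instead.

import torch

from vision_transformer import vit_huge  # from the cloned ijepa repo

# Bare ViT-H/14 encoder matching the IN1K-vit.h.14-300e.pth.tar checkpoint
encoder = vit_huge(patch_size=14)

ckpt = torch.load('IN1K-vit.h.14-300e.pth.tar', map_location='cpu')

# The checkpoint is a full training-state dict; only the weight entries matter here.
# 'opt', 'scaler', 'epoch', 'loss', etc. can be ignored for inference or fine-tuning.
pretrained_dict = ckpt['encoder']  # or ckpt['target_encoder']

# Strip the DDP "module." prefix so the keys match the bare module
cleaned = {k.removeprefix('module.'): v for k, v in pretrained_dict.items()}
msg = encoder.load_state_dict(cleaned)
print(msg)  # expect no missing or unexpected keys

If any keys still come back as missing or unexpected, printing a few entries of cleaned.keys() next to encoder.state_dict().keys() usually makes the mismatch obvious.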

CUN-bjy · Oct 17 '23, 13:10

Thank you so much @CUN-bjy, it worked! However, I couldn't classify the images due to limited computing resources. Thanks again for your help!

namrahrehman · Oct 17 '23, 17:10

Hello everyone. Any insights as to how one can take a checkpoint/pretrained model and use it for some downstream task? That is, load the already-trained weights into a model, freeze them, and use them to train a classifier on another dataset (e.g. CIFAR-10).

Also, what is the complete answer to the question posed above? For example, where does the `encoder` variable come from? A complete code snippet would be of great help.

I've figured out that the steps for loading the checkpoint are the following:

  • Take the state_dict
  • Initialize the corresponding ViT (e.g. ViT-H with the init_model function from src.helper.py)
  • Initialize an optimizer with init_opt
  • Then? Which parts of the IJEPA architecture are needed to utilize the embeddings in some other task as described earlier?

This is for research purposes by myself, an undergrad.

Thank you in advance!

lazarosgogos · Mar 28 '24, 15:03

> Any insights as to how one can take a checkpoint/pretrained model and use it for some downstream task? […]

It would be great to have a solution to this if someone has managed to get it working!

lange4531 · Apr 10 '24, 14:04

> Any insights as to how one can take a checkpoint/pretrained model and use it for some downstream task? […]

You can take the pretrained target encoder and fine-tune it on your custom dataset. But fine-tuning would be costly, as you can see from the size of the encoder: it has 32 blocks, and ViT-based models need a lot of data (and GPU memory) to be tuned for the task at hand. A cheaper option is training an MLP head (1 layer, 2 layers, ... N layers) on top of the frozen encoder for the task of interest.

Possible downstream tasks include image similarity, classification, etc. Feature extraction is the main component; you can use it anywhere!
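
As a rough illustration, here is a minimal linear-probe sketch along those lines. It is only a sketch under a few assumptions: the frozen `encoder` is the ViT-H/14 loaded as in the earlier comment (embedding dim 1280), its forward pass returns a [batch, num_patches, embed_dim] tensor of patch tokens that are simply mean-pooled here, and CIFAR-10 images are resized to 224 with standard ImageNet normalization.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Freeze the pretrained encoder (e.g. the target_encoder loaded as shown above)
encoder = encoder.to(device).eval()
for p in encoder.parameters():
    p.requires_grad = False

# CIFAR-10 resized to the encoder's pretraining resolution
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
train_set = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)

# Trainable head on top of the frozen features (1280 = ViT-H embedding dim);
# replace nn.Linear with a small nn.Sequential MLP for the N-layer variant.
head = nn.Linear(1280, 10).to(device)
opt = torch.optim.AdamW(head.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for images, labels in loader:
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():
        tokens = encoder(images)      # [B, num_patches, embed_dim] patch tokens
        feats = tokens.mean(dim=1)    # mean-pool tokens into one feature vector
    logits = head(feats)
    loss = loss_fn(logits, labels)
    opt.zero_grad()
    loss.backward()
    opt.step()

Full fine-tuning is the same loop with the encoder unfrozen and its parameters added to the optimizer, which is where the higher GPU cost mentioned above comes from.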

VimukthiRandika1997 · Apr 10 '24, 14:04

> Any insights as to how one can take a checkpoint/pretrained model and use it for some downstream task? […]

I have developed fine-tuning code for I-JEPA here, heavily based on ViT-MAE, in order to reproduce the experiments conducted here. Right now it seems to be working, as the loss is decreasing, but I'm not managing to get much reduction in the test error, so I am currently investigating that. If you need help, contact me on Discord (at falsomoralista) or something.

FalsoMoralista · Apr 11 '24, 22:04