ijepa
How to load ijepa checkpoints?
I am trying to use this model for CIFAR-10 classification in Google Colab. To study its layers, I cloned this repo and tried to load the model as follows:
import torch
from vision_transformer import vit_huge
# Initialize the ViT-H model with the specified patch size and resolution
model = vit_huge(patch_size=14, num_classes=1000)  # Adjust num_classes if needed
# Load the state dictionary from the file
state_dict = torch.load('/content/drive/MyDrive/IN1K-vit.h.14-300e.pth.tar')
# Load the state dictionary into the model
model.load_state_dict(state_dict)
# Print the layers/modules of the model for inspection
def print_model_layers(model, prefix=""):
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Module):
            module_name = prefix + "." + name if prefix else name
            print(module_name)
            print_model_layers(module, prefix=module_name)
print_model_layers(model)
but I get the following error:
`RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for VisionTransformer:
    Missing key(s) in state_dict: "pos_embed", "patch_embed.proj.weight", "patch_embed.proj.bias", "blocks.0.norm1.weight", "blocks.0.norm1.bias", "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.proj.weight", "blocks.0.attn.proj.bias", "blocks.0.norm2.weight", "blocks.0.norm2.bias", "blocks.0.mlp.fc1.weight", "blocks.0.mlp.fc1.bias", "blocks.0.mlp.fc2.weight", "blocks.0.mlp.fc2.bias", "blocks.1.norm1.weight", "blocks.1.norm1.bias", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.proj.weight", "blocks.1.attn.proj.bias", "blocks.1.norm2.weight", "blocks.1.norm2.bias", "blocks.1.mlp.fc1.weight", "blocks.1.mlp.fc1.bias", "blocks.1.mlp.fc2.weight", "blocks.1.mlp.fc2.bias", "blocks.2.norm1.weight", "blocks.2.norm1.bias", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.proj.weight", "blocks.2.attn.proj.bias", "blocks.2.norm2.weight", "blocks.2.norm2.bias", "blocks.2.mlp.fc1.weight", "blocks.2.mlp.fc1.bias", "blocks.2.mlp.fc2.weight", "blocks.2.mlp.fc2.bias", "blocks.3.norm1.weight", "blocks.3.norm1.bias", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.proj.weight", "blocks.3.attn.proj.bias", "blocks.3.norm2.weight", "blocks.3.norm2.bias", "blocks.3.mlp.fc1.weight", "blocks.3.mlp.fc1.bias", "blocks.3.mlp.fc2.weight", "blocks.3.mlp.fc2.bias", "blocks.4.norm1.weight", "blocks.4.norm1.bias", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.proj.weight", "blocks.4.attn.proj.bias", "bl...
    Unexpected key(s) in state_dict: "encoder", "predictor", "opt", "scaler", "target_encoder", "epoch", "loss", "batch_size", "world_size", "lr".`
I don't understand which ViT from vision_transformer.py I am supposed to use for the checkpoint (IN1K-vit.h.14-300e.pth.tar), because using vit_huge gives the error above.
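The "Unexpected key(s)" list in the traceback suggests the file is a full training checkpoint (encoder, predictor, target_encoder, optimizer state, ...) rather than a bare state_dict. A minimal sketch for checking this yourself, assuming the file loads with `torch.load`; `summarize_checkpoint` is a hypothetical helper, not part of the repo:

```python
from collections.abc import Mapping


def summarize_checkpoint(ckpt):
    """Describe the top-level entries of a loaded checkpoint dict.

    For every entry that is itself a mapping (e.g. a state_dict), report how
    many keys it has and its first key; for anything else, report its type.
    """
    summary = {}
    for name, value in ckpt.items():
        if isinstance(value, Mapping):
            first_key = next(iter(value), None)
            summary[name] = f"mapping with {len(value)} keys, first key: {first_key!r}"
        else:
            summary[name] = type(value).__name__
    return summary


# Usage with the real file (requires PyTorch; path taken from the question):
#   import torch
#   ckpt = torch.load('/content/drive/MyDrive/IN1K-vit.h.14-300e.pth.tar',
#                     map_location='cpu')
#   for name, info in summarize_checkpoint(ckpt).items():
#       print(name, '->', info)
```

If the `encoder` entry's first key starts with `module.`, the weights were saved from a DDP-wrapped model and the prefix has to be stripped before loading.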
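It seems that the file is a full training checkpoint (it contains `encoder`, `predictor`, `target_encoder`, the optimizer state, etc.), and the encoder weights were saved from a DDP-wrapped module, so their keys carry a `module.` prefix. You tried to load the whole file into a bare encoder. This may be the solution: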
import torch
ckpt = torch.load(load_path, map_location=torch.device('cpu'))
pretrained_dict = ckpt['encoder']
# -- loading encoder (strip the "module." prefix added by DDP)
for k, v in pretrained_dict.items():
    encoder.state_dict()[k[len("module."):]].copy_(v)
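A sketch of the same idea using `load_state_dict` instead of copying tensor by tensor, assuming the encoder keys carry the `module.` prefix the answer strips. `strip_prefix` is a hypothetical helper; `load_path` is the checkpoint path, and `encoder` is assumed here to be built with `vit_huge(patch_size=14)` from the repo's `vision_transformer.py`:

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key that has it.

    Keys without the prefix are kept unchanged, so the helper is also safe to
    call on weights that were saved without a DDP wrapper.
    """
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


# Usage sketch (requires PyTorch and the repo's vision_transformer.py):
#   import torch
#   from vision_transformer import vit_huge
#   encoder = vit_huge(patch_size=14)
#   ckpt = torch.load(load_path, map_location='cpu')
#   # pick 'encoder' (as in the answer above) or 'target_encoder'
#   encoder.load_state_dict(strip_prefix(ckpt['encoder']), strict=True)
#   # with strict=True, any key mismatch raises instead of being skipped
```

With `strict=True`, a missing or unexpected key raises an error instead of silently leaving a parameter at its random initialization, which makes architecture mismatches easier to spot.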
Thank you so much @CUN-bjy, it worked! However, I couldn't classify the images due to limited computing resources. Thanks again for your help!
Hello everyone. Any insights into how one can take a checkpoint/pretrained model and use it for a downstream task? That is, load the already-trained weights into a model, freeze them, and use it to train a classifier for another dataset (e.g. CIFAR-10).
Also, what is the complete answer to the question posed above? For example, where does the `encoder` variable come from? A complete code snippet would be of great help.
I've figured out that the steps for loading the checkpoint are the following:
- Take the state_dict
- Initialize the corresponding ViT (e.g. ViT-H with the `init_model` function from `src/helper.py`)
- Initialize an optimizer with `init_opt`
- Then? Which parts of the I-JEPA architecture are needed to use the embeddings in another task, as described earlier?
This is for my own research as an undergrad.
Thank you in advance!
It would be great to have a solution if someone has managed to get this working!
You can take the pretrained target encoder and fine-tune it on your custom dataset. However, fine-tuning would be costly, as you can see from the size of the encoder: it has 32 blocks, and ViT-based models require a lot of data to be tuned for the task at hand. The GPU requirements are also higher. One possibility would be to train an MLP (1 layer, 2 layers, ..., N layers) on top of the frozen encoder for the task of interest.
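A minimal PyTorch sketch of training a small head on top of a frozen encoder. Assumptions: the encoder returns patch tokens of shape (B, N, D) with no class token, and they are mean-pooled into one feature per image (one common choice, not necessarily what the I-JEPA evaluation uses). `FrozenEncoderClassifier` and `DummyEncoder` are hypothetical names; for the real ViT-H/14 you would pass the loaded encoder with `embed_dim=1280` and presumably resize CIFAR-10's 32x32 images to 224x224 first.

```python
import torch
import torch.nn as nn


class FrozenEncoderClassifier(nn.Module):
    """Frozen pretrained encoder followed by a small trainable head."""

    def __init__(self, encoder, embed_dim, num_classes, hidden_dim=None):
        super().__init__()
        self.encoder = encoder
        for p in self.encoder.parameters():
            p.requires_grad = False  # freeze all pretrained weights
        self.encoder.eval()
        if hidden_dim is None:
            self.head = nn.Linear(embed_dim, num_classes)  # 1-layer linear probe
        else:
            self.head = nn.Sequential(  # 2-layer MLP head
                nn.Linear(embed_dim, hidden_dim),
                nn.GELU(),
                nn.Linear(hidden_dim, num_classes),
            )

    def train(self, mode=True):
        super().train(mode)
        self.encoder.eval()  # the frozen encoder always stays in eval mode
        return self

    def forward(self, x):
        with torch.no_grad():
            tokens = self.encoder(x)   # (B, N, D) patch tokens
        features = tokens.mean(dim=1)  # (B, D), average over patches
        return self.head(features)     # (B, num_classes)


class DummyEncoder(nn.Module):
    """Stand-in for the real encoder: 16x16 patches -> (B, N, D) tokens."""

    def __init__(self, embed_dim=32, patch=16):
        super().__init__()
        self.proj = nn.Conv2d(3, embed_dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        return self.proj(x).flatten(2).transpose(1, 2)  # (B, N, D)


# One training step on random data; only the head's parameters are optimized.
model = FrozenEncoderClassifier(DummyEncoder(embed_dim=32), embed_dim=32, num_classes=10)
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-3)
logits = model(torch.randn(4, 3, 224, 224))
loss = nn.functional.cross_entropy(logits, torch.randint(0, 10, (4,)))
loss.backward()
optimizer.step()
```

Passing a `hidden_dim` switches from a linear probe to a 2-layer MLP; deeper heads follow the same pattern.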
Possible downstream tasks include image similarity, classification, etc. Feature extraction is the main component, and you can use it anywhere!
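For the image-similarity case, a minimal sketch, again assuming the encoder returns (B, N, D) patch tokens that are mean-pooled into one vector per image. `embed` and `most_similar` are hypothetical helper names, and the toy tensors below only stand in for real embeddings:

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def embed(encoder, images):
    """Mean-pool the encoder's patch tokens (B, N, D) into one vector per image."""
    encoder.eval()
    return encoder(images).mean(dim=1)  # (B, D)


def most_similar(query_emb, gallery_emb):
    """For each query, return the index of the most cosine-similar gallery item."""
    q = F.normalize(query_emb, dim=-1)
    g = F.normalize(gallery_emb, dim=-1)
    sims = q @ g.T  # (num_queries, num_gallery) cosine similarities
    return sims.argmax(dim=-1), sims


# Toy example: three orthogonal "image" embeddings and one query near item 1.
gallery = torch.eye(3)
query = torch.tensor([[0.1, 0.9, 0.0]])
idx, sims = most_similar(query, gallery)
```

With a real encoder, `embed(encoder, images)` would produce the query and gallery embeddings.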
I have developed fine-tuning code for I-JEPA here, heavily based on ViT-MAE, in order to reproduce the experiments conducted here. Right now it seems to work, as the loss is decreasing, but I'm not managing to get much reduction in the test error, so I am currently investigating that. If you need help, contact me on Discord (at falsomoralista) or elsewhere.