
[BUG] Loading state dict in a feature extraction network

Open · ioangatop opened this issue 1 year ago · 1 comment

Describe the bug

Hi Ross! I'm facing a small issue with the feature extractor; here are the details:

The create_model function accepts a checkpoint_path argument for loading custom model weights. However, when the model is created as a feature extractor (features_only=True), it is wrapped in the FeatureGetterNet class, and loading fails because the keys no longer match. FeatureGetterNet stores the wrapped model under self.model, so for loading to succeed every state dict key would need a model. prefix, for example cls_token -> model.cls_token.
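The key mismatch can be illustrated with a minimal pure-Python sketch, assuming the checkpoint is a flat mapping from parameter names to values (a plain dict stands in for a PyTorch state_dict). The helper add_prefix is hypothetical and not part of timm; it only shows the renaming that FeatureGetterNet's keys would require.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def add_prefix(state_dict: Dict[str, V], prefix: str = "model.") -> Dict[str, V]:
    """Return a copy of state_dict with `prefix` prepended to every key.

    Hypothetical helper: FeatureGetterNet keeps the wrapped network under
    `self.model`, so its own state_dict keys look like "model.cls_token".
    """
    return {prefix + key: value for key, value in state_dict.items()}


# Made-up checkpoint entries; real values would be tensors.
checkpoint = {"cls_token": 0, "pos_embed": 1, "blocks.0.norm1.weight": 2}
print(sorted(add_prefix(checkpoint)))
# ['model.blocks.0.norm1.weight', 'model.cls_token', 'model.pos_embed']
```

For the ViT case, renaming alone would likely still fail, because the feature-extraction wrapper also prunes the final norm layer, leaving model.norm.weight and model.norm.bias as unexpected keys.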

One workaround is to load the weights after the model has been initialised, but this also fails: some networks, such as the Vision Transformer, prune layers (here the final norm) when built as feature extractors, so the checkpoint's state_dict contains extra keys.
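One way to make the post-initialisation workaround tolerate pruned layers is to drop the checkpoint entries the pruned model does not expect and report them. The sketch below is pure Python under that assumption; drop_unexpected_keys is a hypothetical helper, and in PyTorch model_keys would come from something like backbone.model.state_dict().keys().

```python
from typing import Dict, Iterable, List, Tuple, TypeVar

V = TypeVar("V")


def drop_unexpected_keys(
    state_dict: Dict[str, V], model_keys: Iterable[str]
) -> Tuple[Dict[str, V], List[str]]:
    """Keep only entries the (pruned) model expects; return them plus the dropped keys.

    Hypothetical helper for illustration, not a timm or PyTorch API.
    """
    expected = set(model_keys)
    kept = {k: v for k, v in state_dict.items() if k in expected}
    dropped = sorted(k for k in state_dict if k not in expected)
    return kept, dropped


# Made-up key lists: the final norm is absent from the pruned model.
checkpoint = {"cls_token": 0, "pos_embed": 1, "norm.weight": 2, "norm.bias": 3}
pruned_model_keys = ["cls_token", "pos_embed"]
filtered, dropped = drop_unexpected_keys(checkpoint, pruned_model_keys)
print(dropped)  # ['norm.bias', 'norm.weight']
```

Returning the dropped keys is a deliberate alternative to passing strict=False to torch.nn.Module.load_state_dict: strict=False also stops raising on genuinely missing weights, whereas this filter discards only the extras, so a subsequent strict load would still catch anything absent.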

To Reproduce

from urllib import request

from timm.models import _helpers
import timm


# download weights
request.urlretrieve("https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth", "dino_deitsmall16_pretrain.pth")

# build and load model -- works as expected
model = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    checkpoint_path="dino_deitsmall16_pretrain.pth",
)

# RuntimeError: Error(s) in loading state_dict for FeatureGetterNet:
#   Missing key(s) in state_dict: "model.cls_token", "model.pos_embed", ...
#   Unexpected key(s) in state_dict: "cls_token", "pos_embed", ...
backbone = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    features_only=True,
    checkpoint_path="dino_deitsmall16_pretrain.pth",
)

# RuntimeError: Error(s) in loading state_dict for VisionTransformer:
#   Unexpected key(s) in state_dict: "norm.weight", "norm.bias".
backbone = timm.create_model(
    model_name="vit_small_patch16_224",
    num_classes=0,
    features_only=True,
)
_helpers.load_checkpoint(backbone.model, "dino_deitsmall16_pretrain.pth")
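Both RuntimeErrors above come from a strict load comparing two key sets: keys the module expects but the checkpoint lacks are reported as missing, and checkpoint keys the module does not have are reported as unexpected. A minimal pure-Python sketch of that comparison follows, with short made-up key lists standing in for the real ViT state dicts; diff_keys is a hypothetical name, not a PyTorch function.

```python
from typing import Iterable, List, Tuple


def diff_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) as a strict state_dict load would classify them."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_set - ckpt_set)     # expected by the module, absent from the checkpoint
    unexpected = sorted(ckpt_set - model_set)  # present in the checkpoint, unknown to the module
    return missing, unexpected


checkpoint_keys = ["cls_token", "pos_embed", "norm.weight", "norm.bias"]

# Case 1: loading into the FeatureGetterNet wrapper ("model." prefix, norm pruned).
wrapper_keys = ["model.cls_token", "model.pos_embed"]
print(diff_keys(wrapper_keys, checkpoint_keys))

# Case 2: loading into backbone.model directly; only the pruned norm differs.
inner_keys = ["cls_token", "pos_embed"]
print(diff_keys(inner_keys, checkpoint_keys))  # ([], ['norm.bias', 'norm.weight'])
```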

As always, thanks a lot 🙏

ioangatop, Jun 26 '24 13:06