
Simplify transfer learning by modifying get_model()

Open david-csnmedia opened this issue 1 year ago • 2 comments

🚀 The feature

Currently, torchvision.models.get_model() doesn't let you build a model with a different number of classes while keeping the pre-trained backbone weights for certain model types (namely image classification models such as EfficientNet).

Could something like this be incorporated into the get_model() method, or could another method be created to accommodate?


import torch
import torchvision

# Placeholder values for illustration only; substitute your own.
model_type = "efficientnet_v2_s"
weights_backbone = "DEFAULT"
num_classes = 2

model = torchvision.models.get_model(model_type, weights=weights_backbone)

# Replace the final layer of the classifier so its output size matches num_classes.
# We have to do this after get_model() so we can retain the pre-trained weights, but
# modify the model architecture for our use case.
classifier_layer = model.classifier
last_layer_index = len(classifier_layer) - 1
original_linear_layer = classifier_layer[last_layer_index]

new_linear_layer = torch.nn.Linear(in_features=original_linear_layer.in_features, out_features=num_classes)
classifier_layer[last_layer_index] = new_linear_layer

Motivation, pitch

Raising an error about the backbone weights having a mismatch guides users in a direction that isn't helpful.

Alternatives

No response

Additional context

No response

david-csnmedia avatar Sep 03 '24 21:09 david-csnmedia

Hi @david-csnmedia ,

that kind of model surgery is probably too specific to each model for it to be reliably implemented within get_model(). Note that some model builders allow num_classes to be passed.

NicolasHug avatar Sep 04 '24 08:09 NicolasHug

I believe the simplest method right now is to build the model without pretrained weights first, then load the state_dict manually after removing the classifier entries from it. You may want to verify the weights first (for example, by loading with strict=True to intentionally trigger an error that lists the unexpected/missing keys), since strict=False skips the check for missing and unexpected keys. Here's an example:

from torchvision.models import get_model, get_model_weights
model_name = 'efficientnet_v2_s'
model = get_model(model_name, weights=None, num_classes=2)
model_weights = get_model_weights(model_name).DEFAULT.get_state_dict(progress=True, check_hash=True)
model_weights = {k: v for k,v in model_weights.items() if not k.startswith('classifier')}
model.load_state_dict(model_weights, strict=False)
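
One way to keep some validation with strict=False is to inspect the value returned by load_state_dict, which reports missing_keys and unexpected_keys. The following is only a sketch: the helper names drop_prefixed_keys and check_load_result are hypothetical (not part of torch or torchvision), and it assumes the replaced head lives under the "classifier" prefix, as it does in the EfficientNet example above.

```python
def drop_prefixed_keys(state_dict, prefix):
    """Return a copy of state_dict without entries whose key starts with prefix."""
    return {k: v for k, v in state_dict.items() if not k.startswith(prefix)}


def check_load_result(missing_keys, unexpected_keys, allowed_missing_prefix):
    """Raise if anything other than the new head was skipped during loading.

    missing_keys / unexpected_keys are the two lists reported by
    load_state_dict(..., strict=False).
    """
    bad_missing = [k for k in missing_keys if not k.startswith(allowed_missing_prefix)]
    if bad_missing or unexpected_keys:
        raise RuntimeError(
            f"unexpected load result: missing={bad_missing}, "
            f"unexpected={list(unexpected_keys)}"
        )


# Usage with the snippet above (needs torchvision, so shown as comments):
#   weights = drop_prefixed_keys(full_state_dict, "classifier")
#   result = model.load_state_dict(weights, strict=False)
#   check_load_result(result.missing_keys, result.unexpected_keys, "classifier")
```

With this check, only the classifier keys are allowed to be missing; any backbone key that failed to load, or any key the model does not recognise, raises an error instead of being silently ignored.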

Fredrik00 avatar Oct 20 '25 14:10 Fredrik00