Simplify transfer learning by modifying get_model()
🚀 The feature
Currently, torchvision.models.get_model() doesn't let you build a model with a different number of classes while keeping the pre-trained backbone weights for certain model types (namely image classification models such as EfficientNet).
Could something like the following be incorporated into get_model(), or could a separate function be added to support this?
```python
import torch
import torchvision

# Stand-ins for the attributes the original snippet read from `self` (example values).
model_type = "efficientnet_v2_s"
weights_backbone = "DEFAULT"
num_classes = 2

model = torchvision.models.get_model(model_type, weights=weights_backbone)
# Replace the last linear layer of the classifier so its out_features match num_classes
# (in_features is kept). We have to do this after get_model() so we can retain the
# pre-trained weights, but modify the model architecture for our use case.
classifier_layer = model.classifier
last_layer_index = len(classifier_layer) - 1
original_linear_layer = classifier_layer[last_layer_index]
new_linear_layer = torch.nn.Linear(in_features=original_linear_layer.in_features, out_features=num_classes)
classifier_layer[last_layer_index] = new_linear_layer
```
Motivation, pitch
Raising an error about a mismatch with the pre-trained weights points users in a direction that isn't helpful, when what they usually want is to keep the backbone weights and replace only the classifier.
Hi @david-csnmedia,
that kind of model surgery is probably too specific to each model for it to be reliably implemented within get_model().
Note that some model builders allow num_classes to be passed.
I believe the simplest approach right now is to build the model without pre-trained weights (passing the desired num_classes), then load the pre-trained state_dict after removing the classifier entries from it. You may want to verify the keys manually first (for example, by intentionally letting strict=True raise an error so you can see the unexpected/missing keys), since strict=False skips those key checks (shape mismatches still raise). Here's an example:
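```python
from torchvision.models import get_model, get_model_weights

model_name = 'efficientnet_v2_s'
# Build the architecture with the desired number of classes, but no pre-trained weights.
model = get_model(model_name, weights=None, num_classes=2)
# Download the default pre-trained weights as a plain state_dict.
model_weights = get_model_weights(model_name).DEFAULT.get_state_dict(progress=True, check_hash=True)
# Drop the classifier parameters, whose shapes depend on the number of classes.
model_weights = {k: v for k, v in model_weights.items() if not k.startswith('classifier')}
# strict=False tolerates the missing classifier keys; the returned object lists them.
missing_keys, unexpected_keys = model.load_state_dict(model_weights, strict=False)
```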