detecto

Model.load does not support non-default Model types

Open • alankbi opened this issue 3 years ago • 1 comment

Describe the bug

If someone saves a model file using a non-default model type (`Model.MOBILENET` or `Model.MOBILENET_320`) and later tries to load it using `Model.load`, loading fails with the following error:

```
RuntimeError: Error(s) in loading state_dict for FasterRCNN:
    Missing key(s) in state_dict: "backbone.body.conv1.weight", "backbone.body.bn1.weight", "backbone.body.bn1.bias", "backbone.body.bn1.running_mean", "backbone.body.bn1.running_var", "backbone.body.layer1.0.conv1.weight", "backbone.body.layer1.0.bn1.weight", "backbone.body.layer1.0.bn1.bias", "backbone.body.layer1.0.bn1.running_mean", "backbone.body.layer1.0.bn1.running_var", "backbone.body.layer1.0.conv2.weight".............
```
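A strict `load_state_dict` call compares the parameter names stored in the checkpoint with the names the target model defines, and reports any the checkpoint lacks as missing keys. The missing names in the error (`backbone.body.conv1.weight`, `backbone.body.layer1.0.conv1.weight`, ...) look like ResNet backbone parameters, which suggests `Model.load` builds the default architecture regardless of how the file was saved. The plain-Python sketch below mirrors that comparison to help spot such a mismatch; `diff_state_dict_keys` is a hypothetical helper, not part of detecto or PyTorch, and the key names in the demo are illustrative only.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) key lists, in the spirit of strict load_state_dict.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not define
    """
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


if __name__ == "__main__":
    # Hypothetical key names, for illustration only.
    saved = ["backbone.body.0.0.weight", "roi_heads.box_predictor.cls_score.weight"]
    expected = ["backbone.body.conv1.weight", "roi_heads.box_predictor.cls_score.weight"]
    missing, unexpected = diff_state_dict_keys(saved, expected)
    print("Missing:", missing)        # Missing: ['backbone.body.conv1.weight']
    print("Unexpected:", unexpected)  # Unexpected: ['backbone.body.0.0.weight']
```

To compare a real checkpoint, you could pass `torch.load(file_name, map_location='cpu').keys()` and `model.get_internal_model().state_dict().keys()`, assuming the file holds a plain state dict (as the temporary fix in this issue also assumes).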

Temporary fix

In the meantime, the code below should work as a temporary workaround in place of `Model.load`:

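```python
import torch
from detecto.core import Model

classes = ['dog', 'cat', 'rabbit']
model_type = Model.MOBILENET_320  # must match the type the model was saved with
file_name = 'path/to/saved_model.pth'

# Build a model with the same classes and architecture, then load the saved weights into it
model = Model(classes, model_name=model_type)
model.get_internal_model().load_state_dict(torch.load(file_name, map_location=model._device))
```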

alankbi, Mar 13 '22 18:03