detecto
Model.load does not support non-default Model types
Describe the bug
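If a model is saved with a non-default model type (Model.MOBILENET or Model.MOBILENET_320) and later loaded with Model.load, loading fails. Model.load rebuilds the model with the default architecture, whose state_dict keys (the ResNet backbone keys listed below) do not match the saved MobileNet weights, so it raises the following error: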
RuntimeError: Error(s) in loading state_dict for FasterRCNN:
Missing key(s) in state_dict: "backbone.body.conv1.weight", "backbone.body.bn1.weight", "backbone.body.bn1.bias", "backbone.body.bn1.running_mean", "backbone.body.bn1.running_var", "backbone.body.layer1.0.conv1.weight", "backbone.body.layer1.0.bn1.weight", "backbone.body.layer1.0.bn1.bias", "backbone.body.layer1.0.bn1.running_mean", "backbone.body.layer1.0.bn1.running_var", "backbone.body.layer1.0.conv2.weight".............
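To see why the error lists missing keys, it may help to sketch the key comparison that a strict load_state_dict call performs. The helper diff_state_dict_keys is hypothetical (it is not part of detecto or PyTorch). The ResNet key names come from the error above, and the MobileNet key name is illustrative only.

```python
def diff_state_dict_keys(model_keys, saved_keys):
    """Compare parameter names the way a strict load_state_dict does.

    Returns (missing, unexpected):
      missing    -- keys the model expects but the saved file lacks
      unexpected -- keys in the saved file that the model has no slot for
    Both lists are sorted here for readability; PyTorch itself reports
    them in module order.
    """
    expected = set(model_keys)
    saved = set(saved_keys)
    missing = sorted(expected - saved)
    unexpected = sorted(saved - expected)
    return missing, unexpected


# Keys of the default (ResNet-50 FPN) model, taken from the error message,
# plus one box-predictor key that both architectures share.
default_keys = [
    "backbone.body.conv1.weight",
    "backbone.body.bn1.weight",
    "roi_heads.box_predictor.cls_score.weight",
]
# Hypothetical, illustrative key for a MobileNet backbone.
saved_mobilenet_keys = [
    "backbone.body.0.0.weight",
    "roi_heads.box_predictor.cls_score.weight",
]

missing, unexpected = diff_state_dict_keys(default_keys, saved_mobilenet_keys)
print(missing)     # ['backbone.body.bn1.weight', 'backbone.body.conv1.weight']
print(unexpected)  # ['backbone.body.0.0.weight']
# A strict load raises RuntimeError whenever either list is non-empty.
```

Building the model with the same model_name that was used when saving makes both lists empty, which is what the temporary fix relies on.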
Temporary fix
In the meantime, the following code should work as a temporary replacement for Model.load:
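import torch
from detecto.core import Model

classes = ['dog', 'cat', 'rabbit']  # must match the classes used when saving
model_type = Model.MOBILENET_320  # must match the model type used when saving
file_name = 'path/to/saved_model.pth'
# Build the matching architecture, then load the saved weights into it
model = Model(classes, model_name=model_type)
model.get_internal_model().load_state_dict(torch.load(file_name, map_location=model._device))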