detecto

Avoiding model download

Open: carlok opened this issue 3 years ago • 1 comment

Hi, a two-line script like

```python
from detecto import core
model = core.Model.load('mypath', ['foo'])
```

still tries to download the default model from the Internet.

My use case is an HPC cluster whose compute nodes are isolated from the Internet.

I have a .pth file that I just want to load and then train further with fit, but I don't know how to do it. Is this possible? Thanks.

carlok, Apr 14 '22 07:04

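You could try using `core.Model(['foo'], pretrained=False)` and then loading the weights manually:

```python
model = core.Model(['foo'], pretrained=False)  # needs: import torch; from detecto import core
model.get_internal_model().load_state_dict(torch.load(path, map_location=model._device))
```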

Setting `pretrained=False` will prevent Detecto from downloading a model pre-trained on COCO train2017. However, it will most likely still download a backbone pre-trained on ImageNet (see docs). If you want to prevent that too and train a model completely from scratch (backbone and all), that's probably beyond the scope of Detecto, but if you're willing to use plain PyTorch, I think it could be done with something along these lines:

```python
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

classes = ['foo']
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
in_features = model.roi_heads.box_predictor.cls_score.in_features
# one output per class plus one for the background class
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, len(classes) + 1)

# send model to GPU, train, etc. using plain PyTorch
```

alankbi, Apr 14 '22 17:04