
Stop hardcoding `pretrained` kwarg for image model

Open · fepegar opened this issue 3 years ago · 0 comments

```python
import torch

github = 'microsoft/hi-ml:main'
# Expected: the returned model has randomly initialized weights.
model = torch.hub.load(github, 'biovil_resnet', pretrained=False)
```

I would expect `model` to contain randomly initialized weights, but in

https://github.com/microsoft/hi-ml/blob/681e4281d0439dc554a410924b1fd5878cc31360/hi-ml-multimodal/src/health_multimodal/image/model/model.py#L165

we enforce ImageNet weights, regardless of the `pretrained` argument passed by the user.

The easiest solution is to remove the hard-coded argument:

```python
encoder = encoder_class(**kwargs)
```
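A minimal sketch of why the hard-coded argument matters, using a hypothetical stand-in encoder. `StandInEncoder`, `build_encoder_forced` and `build_encoder_forwarded` are illustrative names, not part of hi-ml, and the "forced" version only assumes that the caller's flag is ignored; the real code at the linked line may differ in detail.

```python
from typing import Any, Type


class StandInEncoder:
    """Hypothetical stand-in for an image encoder; it only records its init args."""

    def __init__(self, pretrained: bool = False, **kwargs: Any) -> None:
        self.pretrained = pretrained
        self.kwargs = kwargs


def build_encoder_forced(encoder_class: Type, pretrained: bool = False, **kwargs: Any):
    # Mimics the reported behaviour: the caller's `pretrained` value is
    # ignored and ImageNet weights are always requested.
    return encoder_class(pretrained=True, **kwargs)


def build_encoder_forwarded(encoder_class: Type, **kwargs: Any):
    # Proposed fix: forward the caller's kwargs unchanged, so
    # `pretrained=False` actually reaches the encoder.
    return encoder_class(**kwargs)


forced = build_encoder_forced(StandInEncoder, pretrained=False)
forwarded = build_encoder_forwarded(StandInEncoder, pretrained=False)
print(forced.pretrained, forwarded.pretrained)  # True False
```

With the forwarding version, the encoder's own default applies when the caller passes nothing, and an explicit `pretrained=False` is respected.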

A nicer solution might be to create a new parameter, e.g. `weights`, that can take values such as `random`, `imagenet` or `biovil`. I think TorchVision now uses a similar syntax:

```python
from torchvision.models import resnet50, ResNet50_Weights

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
```
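A rough sketch of what the proposed `weights` parameter could look like, modelled loosely on the TorchVision pattern (where `weights=None` means random initialization). `ImageEncoderWeights` and `resolve_weights` are hypothetical names, not an existing hi-ml API, and the returned flags are only a framework-independent placeholder for the actual loading logic.

```python
from enum import Enum
from typing import Dict, Optional, Union


class ImageEncoderWeights(Enum):
    """Hypothetical enum for the proposed `weights` parameter."""

    RANDOM = "random"
    IMAGENET = "imagenet"
    BIOVIL = "biovil"


def resolve_weights(weights: Optional[Union[ImageEncoderWeights, str]] = None) -> Dict[str, bool]:
    """Translate the proposed `weights` argument into initialization flags.

    `None` behaves like TorchVision's `weights=None`, i.e. random initialization.
    Strings are accepted via their enum values; unknown values raise ValueError.
    """
    if weights is None:
        weights = ImageEncoderWeights.RANDOM
    # Enum lookup accepts either a member or its string value.
    weights = ImageEncoderWeights(weights)
    return {
        "use_imagenet_weights": weights is ImageEncoderWeights.IMAGENET,
        "load_biovil_checkpoint": weights is ImageEncoderWeights.BIOVIL,
    }


print(resolve_weights("imagenet"))
```

An enum keeps the set of valid options explicit and avoids the ambiguity of a boolean `pretrained` flag, which cannot distinguish between ImageNet and BioViL weights.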

@Shruthi42 @ozan-oktay

fepegar · Sep 28 '22 15:09