hi-ml
Stop hardcoding `pretrained` kwarg for image model
```python
import torch

github = 'microsoft/hi-ml:main'
model = torch.hub.load(github, 'biovil_resnet', pretrained=False)
```
I would expect `model` to contain randomly initialized weights, but in
https://github.com/microsoft/hi-ml/blob/681e4281d0439dc554a410924b1fd5878cc31360/hi-ml-multimodal/src/health_multimodal/image/model/model.py#L165
the `pretrained` kwarg is hardcoded, so we always enforce ImageNet weights regardless of what the caller passes.
The easiest solution is to remove the hardcoded argument, so that the caller's `pretrained` value is passed through in `kwargs`:
```python
encoder = encoder_class(**kwargs)
```
A nicer solution might be to create a new parameter, maybe `weights`, that could take values such as `random`, `imagenet` or `biovil`. I think TorchVision now uses a similar syntax:
```python
from torchvision.models import resnet50, ResNet50_Weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
```
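A minimal sketch of how a `weights` option could replace the hardcoded kwarg. Everything here is hypothetical and not part of hi-ml: the `ImageEncoderWeights` enum, the `build_encoder` helper, and the caller-supplied `load_biovil_checkpoint` callback (how hi-ml actually loads BioViL weights is not shown in this issue). `StubEncoder` stands in for the real `encoder_class` so the example runs without torch.

```python
from enum import Enum
from typing import Any, Callable, Optional, Union


class ImageEncoderWeights(str, Enum):
    """Hypothetical options for initializing the image encoder."""
    RANDOM = "random"      # randomly initialized weights
    IMAGENET = "imagenet"  # ImageNet-pretrained weights
    BIOVIL = "biovil"      # BioViL checkpoint loaded on top of the architecture


def build_encoder(
    encoder_class: Callable[..., Any],
    weights: Union[ImageEncoderWeights, str] = ImageEncoderWeights.IMAGENET,
    load_biovil_checkpoint: Optional[Callable[[Any], None]] = None,
    **kwargs: Any,
) -> Any:
    """Hypothetical helper: derive `pretrained` from `weights` instead of hardcoding it."""
    if "pretrained" in kwargs:
        raise TypeError("pass `weights` instead of `pretrained`")
    weights = ImageEncoderWeights(weights)  # accepts the enum member or its string value
    # Only the ImageNet option asks the encoder to load pretrained weights.
    encoder = encoder_class(pretrained=weights is ImageEncoderWeights.IMAGENET, **kwargs)
    if weights is ImageEncoderWeights.BIOVIL:
        if load_biovil_checkpoint is None:
            raise ValueError("weights='biovil' needs a checkpoint loader")
        # Assumption: the checkpoint overwrites all encoder weights,
        # so skipping the ImageNet download is safe.
        load_biovil_checkpoint(encoder)
    return encoder


class StubEncoder:
    """Stand-in for the real encoder class; records how it was built."""

    def __init__(self, pretrained: bool = False) -> None:
        self.pretrained = pretrained
        self.loaded_biovil = False


def fake_biovil_loader(encoder: StubEncoder) -> None:
    encoder.loaded_biovil = True  # a real loader would call load_state_dict here


random_encoder = build_encoder(StubEncoder, weights="random")
imagenet_encoder = build_encoder(StubEncoder)
biovil_encoder = build_encoder(
    StubEncoder, weights=ImageEncoderWeights.BIOVIL, load_biovil_checkpoint=fake_biovil_loader
)
```

Rejecting an explicit `pretrained` keeps the two ways of choosing weights from disagreeing silently, which is exactly the failure described above.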
@Shruthi42 @ozan-oktay