Inference pre-processing transforms when not using pre-trained weights?
This is somewhat of a niche use-case. What would be a good way to run the pre-processing transforms when we don't need the pre-trained weights? This happens in my case because I'm just benchmarking for speed, but it could also apply to a user who doesn't have easy access to the internet.
```python
import torch
from torchvision import models

batch = torch.randint(0, 256, size=(4, 3, 224, 224))  # fake image batch

model = models.vit_h_14()  # I don't need pre-trained weights
weights = models.ViT_H_14_Weights.DEFAULT  # I only need this to pre-process the input
inpt = weights.transforms()(batch)
model(inpt)
```
fails with

```
Wrong image height!
```

BTW, it'd be nice if the message indicated the expected and actual sizes; I had to dig to find that the issue was 518 != 224.
because the `image_size` parameter wasn't overridden, since `weights` is `None`:
https://github.com/pytorch/vision/blob/701d7731660ea54f1ab00c792e9e018569035e2d/torchvision/models/vision_transformer.py#L318-L322
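For context, the linked builder logic amounts to something like the sketch below (a paraphrase, not the actual torchvision code): `image_size` is only taken from the weights' metadata when `weights` is passed, and otherwise falls back to the ViT default of 224.

```python
def resolve_image_size(weights, kwargs):
    # Paraphrased sketch of the linked builder lines (not the actual code).
    if weights is not None:
        # Pre-trained metadata drives image_size only when weights are given.
        assert weights.meta["min_size"][0] == weights.meta["min_size"][1]
        kwargs["image_size"] = weights.meta["min_size"][0]
    # With weights=None nothing is overridden, so the default of 224 wins,
    # hence the 518 != 224 mismatch with the DEFAULT weights' transforms.
    return kwargs.pop("image_size", 224)
```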
I think in this particular case, we can do:
```python
import torch
from torchvision import models

batch = torch.randint(0, 256, size=(4, 3, 224, 224))

weights = models.ViT_H_14_Weights.DEFAULT
inpt = weights.transforms()(batch)
model = models.vit_h_14(image_size=inpt.shape[-1])  # pass image_size to the model
model(inpt)
```
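A variant, in case you prefer not to run the transforms first: the weight enums carry a `min_size` entry in their metadata, and reading it doesn't trigger a download. Assuming that entry is present for these weights (it is for the current classification weights, as far as I can tell), this should work too:

```python
import torch
from torchvision import models

weights = models.ViT_H_14_Weights.DEFAULT
# ViT expects square inputs, so min_size gives the expected image size
# (518 for these weights); reading meta doesn't download anything.
model = models.vit_h_14(image_size=weights.meta["min_size"][0])

batch = torch.randint(0, 256, size=(4, 3, 224, 224))
model(weights.transforms()(batch))
```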
As for the assert message, I think this is good feedback! I created a PR to patch this: #6583
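Roughly, the idea is to make the size check report both sizes, something along these lines (a hypothetical sketch; see #6583 for the actual change):

```python
import torch

def check_height(h, expected):
    # Hypothetical sketch of the friendlier message; see #6583 for the real patch.
    torch._assert(h == expected, f"Wrong image height! Expected {expected} but got {h}!")
```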
It seems like the issue is resolved. Shall we close this then?