Swin in this repo + dynamic resolution
📚 The doc issue
Does the Swin impl in this repo support arbitrary, dynamic, execution-time-defined input resolutions (same as other backbones)?
Initially Swin was trained to support only one resolution, but hacks can be applied to support arbitrary resolutions. Two repos with such hacks:
- https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py
- https://github.com/megvii-research/SOLQ/blob/main/models/swin_transformer.py
Related issues: https://github.com/microsoft/SimMIM/issues/13 https://github.com/microsoft/esvit/issues/17
cc @YosuaMichael
Hi @vadimkantorov, from my understanding the answer is yes.
The swin_transformer implementation in TorchVision doesn't restrict the input resolution. However, take note that the accuracy may change depending on your input resolution (or transform).
For example, the original swin_transformer evaluation using resize_size=232 and crop_size=224 achieves Acc@1 81.472 Acc@5 95.780 on ImageNet-1k. I have tried running the validation with resize_size=128 and crop_size=112 on the same dataset and got: Acc@1 69.248 Acc@5 88.824.
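For reference, here is a rough sketch of what I mean by resize_size/crop_size: it is just the preprocessing transform (shown here with the 128/112 values and the standard ImageNet normalization constants), not a parameter of the model. The exact built-in preset, including interpolation mode, can be obtained from the weights enum's transforms().

```python
import torch
import torchvision
from torchvision import transforms

# Preprocessing used for the second run above: resize the shorter side to 128,
# then take a central 112x112 crop. The model itself is unchanged.
preprocess = transforms.Compose([
    transforms.Resize(128),
    transforms.CenterCrop(112),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

m = torchvision.models.swin_t(weights="DEFAULT").eval()
# `img` would be a PIL image from the ImageNet-1k validation set:
# logits = m(preprocess(img).unsqueeze(0))  # shape: [1, 1000]
```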
Well, resize_size makes it sound like the input gets resized, so the model always gives a fixed output resolution.
From what I understand, the two impls I mentioned above allow the output resolution to scale with the input resolution.
At the very least, the actual behavior should be clearly explained in the docs...
> Well, resize_size makes it sound like the input gets resized, so the model always gives a fixed output resolution.
> From what I understand, the two impls I mentioned above allow the output resolution to scale with the input resolution.
Hi @vadimkantorov, just to clarify, the resize_size that I mentioned is part of the transform applied before we do the classification task. It is not a parameter of the swin_transformer itself (sorry for the confusion!).
For output resolution, can you explain more what you mean by it?
From what I understand, the swin transformer model in TorchVision is meant as a classification model, so it will always produce a fixed-size output according to the num_classes parameter.
As for the two links that you provided, I think their implementations focus on use as a backbone, and the model returns the outputs of all the layers.
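To make that concrete, here is a small sketch (extrapolating from the torchvision implementation, where adaptive pooling sits before the final linear layer): the full classifier returns a fixed [batch, num_classes] tensor even when the input resolution changes.

```python
import torch
import torchvision

m = torchvision.models.swin_t(weights="DEFAULT").eval()

with torch.no_grad():
    # The classification head (adaptive average pooling + linear layer)
    # collapses the spatial dimensions, so the output size is fixed by
    # num_classes regardless of the input resolution.
    print(m(torch.rand(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])
    print(m(torch.rand(1, 3, 150, 200)).shape)  # torch.Size([1, 1000])
```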
I think for all models in torchvision, being able to serve as a backbone is the top-1 or top-2 usage anyway. People usually figure out how to remove the classifier and just use the features of some layer (see the sketch after this comment). So this is important to clarify in the docs and an important scenario in general (especially because all earlier models are convolutional and can easily support any input spatial size).
Those two repos above achieve this by hacking the architecture a bit. I guess it won't be a stretch to say that a lot of people are interested in using Swin as a better backbone.
Probably this repo/detectron2 also has swin adapted as a backbone: https://github.com/facebookresearch/detectron2/tree/main/projects/ViTDet
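As a rough sketch of the "remove the classifier" route (assuming the classifier attribute is named head, as in the current torchvision SwinTransformer), replacing it with an identity yields the pooled backbone features:

```python
import torch
import torchvision

m = torchvision.models.swin_t(weights="DEFAULT").eval()
# Replace the final linear classifier with an identity so the model
# returns pooled features instead of class logits.
m.head = torch.nn.Identity()

with torch.no_grad():
    feats = m(torch.rand(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 768])
```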
Hi @vadimkantorov, I can confirm that the swin transformer model in torchvision can support different input sizes. Here is code to show this:

```python
import torchvision
import torch

m = torchvision.models.swin_t(weights="DEFAULT")
x1 = torch.rand((1, 3, 224, 224))
x2 = torch.rand((1, 3, 150, 200))
x3 = torch.rand((1, 3, 123, 173))
print(m.features(x1).shape)  # torch.Size([1, 7, 7, 768])
print(m.features(x2).shape)  # torch.Size([1, 5, 7, 768])
print(m.features(x3).shape)  # torch.Size([1, 4, 6, 768])
```
As we can see, it accepts different input sizes and the features output has a different output resolution for each.
Compared to the model in detectron2, the difference is that they output all the layers while torchvision only outputs the last layer. Here is the code for detectron2 to illustrate:
```python
import detectron2.modeling.backbone.swin as swin

m2 = swin.SwinTransformer()
keys = ["p0", "p1", "p2", "p3"]  # expected output keys

print({key: val.shape for key, val in m2(x1).items()})
# {'p0': torch.Size([1, 96, 56, 56]), 'p1': torch.Size([1, 192, 28, 28]), 'p2': torch.Size([1, 384, 14, 14]), 'p3': torch.Size([1, 768, 7, 7])}
print({key: val.shape for key, val in m2(x2).items()})
# {'p0': torch.Size([1, 96, 38, 50]), 'p1': torch.Size([1, 192, 19, 25]), 'p2': torch.Size([1, 384, 10, 13]), 'p3': torch.Size([1, 768, 5, 7])}
print({key: val.shape for key, val in m2(x3).items()})
# {'p0': torch.Size([1, 96, 31, 44]), 'p1': torch.Size([1, 192, 16, 22]), 'p2': torch.Size([1, 384, 8, 11]), 'p3': torch.Size([1, 768, 4, 6])}
```
We can see that the last layer has a similar layout (it only differs in the dimension order: torchvision returns channels-last [N, H, W, C], while detectron2 returns [N, C, H, W]).
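For anyone who wants the detectron2-style layout or the per-stage outputs from the torchvision model, here is a rough sketch. The features.N node names and the permute assume the current torchvision swin_t implementation; get_graph_node_names can be used to verify them on your version.

```python
import torch
import torchvision
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)

m = torchvision.models.swin_t(weights="DEFAULT").eval()
x1 = torch.rand((1, 3, 224, 224))

# Last stage only: permute channels-last [N, H, W, C] -> [N, C, H, W]
# to match detectron2's 'p3' layout.
print(m.features(x1).permute(0, 3, 1, 2).shape)  # torch.Size([1, 768, 7, 7])

# All four stages: features.1/3/5/7 are the Swin stage outputs
# (features.0/2/4/6 are the patch embedding / patch merging layers).
# Uncomment the next line to list the available node names:
# print(get_graph_node_names(m)[1])
backbone = create_feature_extractor(
    m,
    return_nodes={"features.1": "p0", "features.3": "p1",
                  "features.5": "p2", "features.7": "p3"},
)
with torch.no_grad():
    feats = backbone(x1)
print({key: val.shape for key, val in feats.items()})
# channels-last, e.g. p0: [1, 56, 56, 96] ... p3: [1, 7, 7, 768]
```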