vit-pytorch
About flattening the patch
Hi,
Thanks for sharing this implementation. I have one question about flattening each HxW patch to a D-dimensional embedding: you use an FC layer to map it to a fixed dimension. But in the original paper they train at 224x224 and test at 384x384, which cannot be done if the flattening layer is fixed. Also, in another repo you shared (https://github.com/rwightman/pytorch-image-models/blob/6f43aeb2526f9fc35cde5262df939ea23a18006c/timm/models/vision_transformer.py#L146), they use a conv layer to avoid the resolution mismatch problem. Do you know which one is correct? Thanks!
@yueruchen Hi Yifan! As long as both of your image sizes are divisible by the patch size, and you instantiate ViT with image_size set to the maximum size you will be using (in your case 384), it should work fine for smaller images you pass in.
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 512,    # set to the largest resolution you plan to feed in
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

train_img = torch.randn(1, 3, 256, 256)   # smaller training resolution
test_img = torch.randn(1, 3, 512, 512)    # full 512 x 512 resolution

preds_train = v(train_img)  # (1, 1000)
preds_test = v(test_img)    # (1, 1000)
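On the FC-versus-conv question: the two projections are not really competing choices. A Linear layer applied to flattened P x P patches computes exactly the same thing as an nn.Conv2d with kernel_size = stride = P (which is what timm's patch embedding uses), so neither is "more correct". The per-patch input dimension (P * P * C) does not depend on the image resolution; a larger image only produces more patches, i.e. a longer token sequence, so it is the positional embedding rather than the projection that has to accommodate a new resolution. Below is a minimal sketch, not code from either repo, that checks the equivalence numerically; the patch flattening order and the weight reshape are assumptions chosen to be consistent with each other, and the sizes mirror the example above.

import torch
import torch.nn as nn
from einops import rearrange

patch_size, channels, dim = 32, 3, 1024
img = torch.randn(1, channels, 256, 256)

# flatten non-overlapping 32x32 patches, then project with a fixed Linear layer
patches = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
linear = nn.Linear(channels * patch_size ** 2, dim)
out_linear = linear(patches)                              # (1, 64, 1024)

# the same projection expressed as a strided conv (timm-style patch embedding)
conv = nn.Conv2d(channels, dim, kernel_size = patch_size, stride = patch_size)
with torch.no_grad():
    # copy the Linear weights into the conv kernel, matching the flattening order above
    conv.weight.copy_(rearrange(linear.weight, 'd (p1 p2 c) -> d c p1 p2', p1 = patch_size, p2 = patch_size))
    conv.bias.copy_(linear.bias)
out_conv = rearrange(conv(img), 'b d h w -> b (h w) d')   # (1, 64, 1024)

print(torch.allclose(out_linear, out_conv, atol = 1e-4))  # True, up to float error

If I read the two repos correctly, the remaining resolution dependence lives in the positional embedding: timm resizes/interpolates it when the patch grid changes, while the example above (instantiating ViT with the maximum image_size and passing smaller images) relies on only the first n + 1 learned positions being added to the shorter token sequence.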