Support different image sizes for ViT with relative position encoding.
Currently, timm supports different image sizes at test time for ViT with absolute position encoding, but not for ViT with relative position encoding. Models with relative position encoding are important as well. In applications that are sensitive to image processing operations, resizing the input image is not a reasonable option.
Incorporating runtime interpolation of the absolute position embedding would be beneficial, as it would allow the use of various image sizes during inference. This addition should not pose a significant challenge.
- The code for interpolating relative position embeddings, specifically Swin-style bias tables, is doable but a bit more fiddly, and it needs to be done in each block instead of just one spot; a sketch of that per-block resampling follows this list.
- Runtime position interpolation for either absolute or relative embeddings is definitely doable. I have a prototype for ViT that works fine and relaxes the constraint from img_size == fixed_value to img_size % patch_size == 0. You can go further and pad the input to remove the constraint entirely, but I don't think that's a good idea performance-wise.
- The downside of runtime interpolation is the tensor-size checks: they can cause issues with tracing/export, and they require shape info from the patch embed, so the interface changes a bit. If we keep things simple, there is ambiguity in the shape check with timm's support for non-square images (and, technically, non-square patch sizes), so it is simplest to assume the original model is always square.
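For illustration, here is a minimal sketch of that per-block resampling, assuming Swin-style relative position bias tables of shape [(2*H-1)*(2*W-1), num_heads]. The helper name and signature are hypothetical, and each block would also need its relative_position_index rebuilt for the new window size:

```python
import torch
import torch.nn.functional as F


def resample_rel_pos_bias_table(
        table: torch.Tensor,  # [(2*H-1) * (2*W-1), num_heads]
        old_window_size,      # (H, W) the table was trained for
        new_window_size,      # (H', W') wanted at inference
) -> torch.Tensor:
    """Bicubically resample a Swin-style relative position bias table.

    Hypothetical helper for illustration only; it must be applied to every
    block, since each block owns its own table.
    """
    old_h, old_w = 2 * old_window_size[0] - 1, 2 * old_window_size[1] - 1
    new_h, new_w = 2 * new_window_size[0] - 1, 2 * new_window_size[1] - 1
    num_heads = table.shape[1]
    # [(2H-1)*(2W-1), heads] -> [1, heads, 2H-1, 2W-1] so F.interpolate sees a 2D grid
    table = table.permute(1, 0).reshape(1, num_heads, old_h, old_w)
    table = F.interpolate(table, size=(new_h, new_w), mode='bicubic', align_corners=False)
    # back to [(2H'-1)*(2W'-1), heads]
    return table.reshape(num_heads, new_h * new_w).permute(1, 0)
```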
The relevant parts of the prototype (torch, List, checkpoint_seq, and resample_abs_pos_embed are assumed to be in scope, as in timm's vision_transformer.py). patch_embed returns the token grid size alongside the features, and _pos_embed uses it to resample the absolute position embedding at runtime:

```python
def forward_features(self, x):
    # patch_embed now also returns the (H, W) token grid so the position
    # embedding can be resampled to match the actual input resolution
    x, grid_size = self.patch_embed(x)
    x = self._pos_embed(x, grid_size)
    x = self.patch_drop(x)
    x = self.norm_pre(x)
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    x = self.norm(x)
    return x

def _pos_embed(self, x, grid_size: List[int]):
    if self.no_embed_class:
        # deit-3, updated JAX (big vision)
        # position embedding does not overlap with class token, add then concat
        x = x + resample_abs_pos_embed(self.pos_embed, grid_size, num_prefix_tokens=0)
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    else:
        # original timm, JAX, and deit vit impl
        # pos_embed has entry for class token, concat then add
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + resample_abs_pos_embed(self.pos_embed, grid_size, num_prefix_tokens=self.num_prefix_tokens)
    return self.pos_drop(x)
```
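For context, resample_abs_pos_embed above is assumed to split off the prefix (class) tokens, reshape the spatial part of the embedding back to a 2D grid, bicubically interpolate it to the new grid size, and flatten it again. A minimal sketch of that behaviour (the real timm helper may infer the old grid size and accept different arguments):

```python
from typing import List

import torch
import torch.nn.functional as F


def resample_abs_pos_embed_sketch(
        pos_embed: torch.Tensor,     # [1, num_prefix_tokens + H*W, C]
        new_grid_size: List[int],    # (H', W') target token grid
        old_grid_size: List[int],    # (H, W) grid the embedding was trained for
        num_prefix_tokens: int = 1,  # class (and similar) tokens are kept as-is
) -> torch.Tensor:
    """Bicubically resample the spatial part of an absolute position embedding."""
    if list(new_grid_size) == list(old_grid_size):
        return pos_embed
    prefix = pos_embed[:, :num_prefix_tokens]
    grid = pos_embed[:, num_prefix_tokens:]
    embed_dim = grid.shape[-1]
    # [1, H*W, C] -> [1, C, H, W] so F.interpolate sees a 2D grid
    grid = grid.reshape(1, old_grid_size[0], old_grid_size[1], embed_dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=list(new_grid_size), mode='bicubic', align_corners=False)
    # back to [1, H'*W', C] and re-attach the prefix tokens
    grid = grid.permute(0, 2, 3, 1).reshape(1, -1, embed_dim)
    return torch.cat([prefix, grid], dim=1)
```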
Now, in the ViT code I can only see relative position encoding. Is there an argument to choose absolute position encoding instead?