Support different image size for ViT with relative position encoding.

Luciennnnnnn opened this issue 2 years ago • 4 comments

Currently, timm supports different image sizes at test time for ViT with absolute position encoding, but ViT with relative position encoding is not supported. However, models with relative position encoding are also important. In applications that are sensitive to image processing operations, it is unreasonable to resize the input images.

Luciennnnnnn avatar May 15 '23 09:05 Luciennnnnnn

Incorporating runtime absolute positional embedding interpolation would be beneficial as it allows for the use of various image sizes during inference. This addition should not pose a significant challenge.

leng-yue avatar May 16 '23 03:05 leng-yue
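
A minimal sketch of what such runtime interpolation of an absolute position embedding could look like (illustrative only, not timm's API; the function name is hypothetical, and it assumes a square source grid with no class/prefix tokens):

    import math
    import torch
    import torch.nn.functional as F

    def interpolate_abs_pos_embed(pos_embed: torch.Tensor, new_grid) -> torch.Tensor:
        # pos_embed: (1, N, C) with N = old_h * old_w and no prefix tokens
        _, n, c = pos_embed.shape
        old_size = int(math.sqrt(n))
        # view as (1, C, old_h, old_w) so 2D bicubic interpolation can be applied
        grid = pos_embed.reshape(1, old_size, old_size, c).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=new_grid, mode='bicubic', align_corners=False)
        # back to (1, new_h * new_w, C)
        return grid.permute(0, 2, 3, 1).reshape(1, new_grid[0] * new_grid[1], c)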

  • interpolating the relative pos embed, specifically the swin style, is doable but a bit more fiddly since it needs to be done in each block instead of just one spot (see the relative position bias sketch after the prototype below)
  • runtime position interpolation for either abs or relative is definitely doable, I have a prototype for vit that works fine and relaxes the constraint from img_size == fixed_value to img_size % patch_size == 0; you could go further and pad to remove all constraints, but I don't think that's the best idea re performance
  • the downside of runtime interpolation is the tensor size checks, which can cause issues with tracing/export; it also requires shape info from the patch embed, so it changes the interfaces a bit. When keeping things simple there can be ambiguity in the shape check w/ timm's support for non-square images (and technically patch size), so it simplifies to assuming the original model is always square

rwightman avatar May 17 '23 17:05 rwightman

    def forward_features(self, x):
        x, grid_size = self.patch_embed(x)
        x = self._pos_embed(x, grid_size)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def _pos_embed(self, x, grid_size: List[int]):
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + resample_abs_pos_embed(self.pos_embed, grid_size, num_prefix_tokens=0)
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + resample_abs_pos_embed(self.pos_embed, grid_size, num_prefix_tokens=self.num_prefix_tokens)
        return self.pos_drop(x)

rwightman avatar May 17 '23 17:05 rwightman
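
The prototype above covers the absolute position embedding via resample_abs_pos_embed. For the swin-style relative position bias mentioned in the first bullet, a rough per-block sketch (illustrative only, not timm's implementation; the function name is hypothetical) could look like:

    import torch
    import torch.nn.functional as F

    def resize_rel_pos_bias_table(table: torch.Tensor, old_window, new_window) -> torch.Tensor:
        # table: ((2*Wh - 1) * (2*Ww - 1), num_heads), swin-style relative position bias table
        num_heads = table.shape[1]
        old_h, old_w = 2 * old_window[0] - 1, 2 * old_window[1] - 1
        new_h, new_w = 2 * new_window[0] - 1, 2 * new_window[1] - 1
        # view as a 2D grid per head, (1, num_heads, old_h, old_w), and interpolate
        grid = table.transpose(0, 1).reshape(1, num_heads, old_h, old_w)
        grid = F.interpolate(grid, size=(new_h, new_w), mode='bicubic', align_corners=False)
        # flatten back to ((2*Wh' - 1) * (2*Ww' - 1), num_heads)
        return grid.reshape(num_heads, new_h * new_w).transpose(0, 1)

Each block carries its own bias table, so this resize (plus recomputing the relative position index for the new window) would need to be applied in every block, which is what makes it fiddlier than the single-spot absolute case.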

Now in the ViT code, I can only see relative position encoding. Is there any argument to choose absolute position encoding?

osiriszjq avatar Feb 20 '24 00:02 osiriszjq