dinov2

Positional encoding fails with rectangular input

Open · alexaatm opened this issue 1 year ago · 2 comments

Hi! I ran into the same problem as the one where interpolate_pos_encoding(x, pos_embed) doesn't return the correct dimension for non-square images (w != h). With dinov2, the code crashes in the same function when given a rectangular input.

Here is the function from vision_transformer.py:

def interpolate_pos_encoding(self, x, w, h):
    # Method of the vision transformer class in dinov2/models/vision_transformer.py;
    # relies on the module-level imports of math, torch and nn (torch.nn).
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_size
    h0 = h // self.patch_size
    print(f'DEBUG dinov2 vision_trasnformer.py: w0={w0}, h0={h0}')
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
    print(f'DEBUG dinov2 vision_trasnformer.py: add small number w0={w0}, h0={h0}')

    sqrt_N = math.sqrt(N)
    sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
    # The output grid size is derived from scale_factor, so with
    # interpolate_offset == 0 rounding can leave it one patch short.
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
        scale_factor=(sx, sy),
        mode="bicubic",
        antialias=self.interpolate_antialias,
    )
    print(f'DEBUG dinov2 vision_trasnformer.py: patch_pos_embed.shape={patch_pos_embed.shape}')

    # These checks fail when the interpolated grid is short (e.g. 38 instead of 39).
    assert int(w0) == patch_pos_embed.shape[-2]
    assert int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

I fed in a rectangular image and found that the small offset did not change w0 and h0; see the output:

image_crop.shape=torch.Size([1, 3, 434, 546])
DEBUG dinov2 vision_trasnformer.py: w0=31, h0=39
DEBUG dinov2 vision_trasnformer.py: add small number w0=31.0, h0=39.0
DEBUG dinov2 vision_trasnformer.py: patch_pos_embed.shape=torch.Size([1, 384, 31, 38])
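The 38 in the last DEBUG line can be reproduced with plain Python. This is a minimal sketch, assuming that nn.functional.interpolate derives the output size from scale_factor as floor(input_size * scale_factor) in double precision, and that the stored grid is 37 x 37 (518 / 14). interpolated_size is a hypothetical helper that only mirrors the arithmetic in the function above. The DEBUG output already points at the offset: with interpolate_offset = 0.1 the second line would show a fractional part (about 31.1 and 39.1) rather than 31.0 and 39.0.

```python
import math


def interpolated_size(n_patches: int, grid: int = 37, offset: float = 0.0) -> int:
    """Mimic the output-size computation behind interpolate(..., scale_factor=...).

    Assumption: output size = floor(input_size * scale_factor) in double
    precision, with scale_factor = (n_patches + offset) / sqrt(N) as in
    interpolate_pos_encoding (here sqrt(N) = grid).
    """
    scale = float(n_patches + offset) / math.sqrt(grid * grid)
    return math.floor(grid * scale)


print(interpolated_size(31))              # 31
print(interpolated_size(39))              # 38: 37 * (39 / 37) lands just below 39
print(interpolated_size(39, offset=0.1))  # 39: the offset absorbs the rounding error
```

The product 37 * (39 / 37) rounds to the double just below 39, so flooring drops one patch along that axis, matching the shape [1, 384, 31, 38] in the log.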

Note that in __init__, interpolate_offset defaults to 0.1. Here is the error I got:

File "home/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 204, in interpolate_pos_encoding
    assert int(w0) == patch_pos_embed.shape[-2]
AssertionError

Note: I used the pretrained dinov2_vits14_reg model.

Could it be that these pretrained checkpoints lack an interpolate_offset field, or have it set to 0?

Meanwhile, I will manually set it to 0.1 and see if it helps.

alexaatm, Dec 07 '23

Update: manually adding 0.1 instead of self.interpolate_offset solved the issue, so I assume the pretrained checkpoint has it set to 0.0.

alexaatm, Dec 07 '23

Is there a reason the interpolate call doesn't set the output size directly to (w0,h0) using the size parameter, rather than using the scale_factor parameter?
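One way to follow this suggestion is to request the target grid with size=(w0, h0), so the output is exactly w0 x h0 patches and no offset is needed. This is a minimal sketch, assuming a (1, 1 + M*M, dim) positional table with a square M x M patch grid; interpolate_pos_embed_to_grid is a hypothetical helper, not part of dinov2, and antialiasing is left out.

```python
import math

import torch
import torch.nn.functional as F


def interpolate_pos_embed_to_grid(pos_embed: torch.Tensor, w0: int, h0: int) -> torch.Tensor:
    """Resize a (1, 1 + M*M, dim) table to (1, 1 + w0*h0, dim).

    Hypothetical helper: the target grid is passed via `size`, so the output
    is exactly w0 x h0 patches regardless of floating-point rounding.
    """
    class_pos_embed = pos_embed[:, :1]  # (1, 1, dim), kept unchanged
    patch_pos_embed = pos_embed[:, 1:]  # (1, M*M, dim)
    n, dim = patch_pos_embed.shape[1], patch_pos_embed.shape[2]
    m = math.isqrt(n)
    assert m * m == n, "stored patch grid must be square"
    grid = patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2)  # (1, dim, M, M)
    grid = F.interpolate(grid.float(), size=(w0, h0), mode="bicubic", align_corners=False)
    patch_pos_embed = grid.permute(0, 2, 3, 1).reshape(1, w0 * h0, dim)
    return torch.cat((class_pos_embed, patch_pos_embed.to(pos_embed.dtype)), dim=1)


# Example: a 37 x 37 table (518 px / patch 14) resized for a 434 x 546 input.
pos_embed = torch.randn(1, 1 + 37 * 37, 384)
out = interpolate_pos_embed_to_grid(pos_embed, 434 // 14, 546 // 14)
print(out.shape)  # torch.Size([1, 1210, 384])
```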

steve-landers, Jan 05 '24

> Update: adding manually 0.1 instead of self.interpolate_offset solved the issue. This means the pretrained checkpoint has it set to 0.0, I assume.

I'm facing the same issue. It only fails for certain widths/heights: 546, 602, 1092, ...
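The failing side lengths can be enumerated without PyTorch. This is a minimal sketch, assuming that interpolate computes the output size as floor(input_size * scale_factor) in double precision, a 37 x 37 stored grid, patch size 14 and interpolate_offset = 0; short_by_one is a hypothetical helper. Under these assumptions 546, 602 and 1092 (39, 43 and 78 patches) are all flagged, while an offset of 0.1 removes every mismatch in the scanned range.

```python
import math

PATCH_SIZE = 14  # ViT-*/14 patch size
GRID = 37        # assumed stored grid side: 518 px / 14


def short_by_one(n_patches: int, offset: float = 0.0) -> bool:
    """True if the assumed size rule floor(GRID * scale) undershoots n_patches."""
    scale = float(n_patches + offset) / GRID
    return math.floor(GRID * scale) < n_patches


# Smallest pixel size for each patch count (1 to 99 patches) that comes out short.
failing = [n * PATCH_SIZE for n in range(1, 100) if short_by_one(n)]
print([s for s in (546, 602, 1092) if s in failing])            # [546, 602, 1092]
print(any(short_by_one(n, offset=0.1) for n in range(1, 100)))  # False
```

Any size in the same patch bucket also fails (for example 546 to 559 px all give 39 patches), which is why only certain widths and heights are affected.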

sonsnix, Feb 07 '24

This should hopefully be fixed now with #378.

patricklabatut, Feb 22 '24

Hi @patricklabatut, I am trying to train on rectangular images, and I get the following assertion error even after the PR was merged:

https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L192

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/zsuri/miniconda3/envs/dinov2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 325, in forward
    ret = self.forward_features(*args, **kwargs)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 258, in forward_features
    x = self.prepare_tokens_with_masks(x, masks)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 220, in prepare_tokens_with_masks
    x = x + self.interpolate_pos_encoding(x, w, h)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 192, in interpolate_pos_encoding
    assert N == M * M
AssertionError

It can be reproduced as follows:


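import torch
# SSLMetaArch, setup and args as in the dinov2 training code (assumed already in scope)
model = SSLMetaArch(setup(args))
model.teacher.backbone(torch.zeros((1, 3, 1078, 1918)))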

with the default vitl14.yaml config.

Also, I don't understand the intuition behind these lines for rectangular images: https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L185-L211. I tested returning https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L184 even when h != w, and it seems to run, but I'm not sure whether that produces any unwanted outcomes.
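On the intuition: interpolate_pos_encoding treats the stored patch embeddings as an M x M image, resizes that image to the token grid and flattens it back, so each token receives an embedding from the corresponding relative 2D location. This is a minimal sketch of what skipping that step does, assuming the table was learned on a square grid and that both the table and the tokens are flattened row-major (as the reshape in interpolate_pos_encoding implies); stored_grid_side and reused_positions are hypothetical helpers. A table reused as-is is only shape-compatible when the token counts match, and even then tokens receive embeddings from the wrong 2D positions.

```python
import math


def stored_grid_side(n: int) -> int:
    """Side M of the stored square grid; mirrors the `assert N == M * M` check."""
    m = math.isqrt(n)
    assert n == m * m, f"{n} patch embeddings do not form a square grid"
    return m


def reused_positions(n: int, rows: int, cols: int) -> list:
    """Pair each token's (row, col) in a rows x cols image grid with the
    (row, col) of the stored M x M embedding it would receive if the table
    were added without interpolation (row-major flattening assumed)."""
    m = stored_grid_side(n)
    assert rows * cols == n, "without interpolation the token counts must match"
    return [((i // cols, i % cols), (i // m, i % m)) for i in range(n)]


pairs = reused_positions(16, rows=2, cols=8)
print(pairs[4])  # ((0, 4), (1, 0)): a first-row token gets a second-row embedding
```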

zshn25, Apr 29 '24