Positional encoding fails with rectangular input
Hi! I stumbled on the same issue: `interpolate_pos_encoding(x, pos_embed)` does not return the correct dimensions for images that are not square (w != h) when using dinov2; the code crashed in the same function when given rectangular input...
Here is the function from vision_transformer.py (with my debug prints added):

```python
def interpolate_pos_encoding(self, x, w, h):
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_size
    h0 = h // self.patch_size
    print(f'DEBUG dinov2 vision_transformer.py: w0={w0}, h0={h0}')
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
    print(f'DEBUG dinov2 vision_transformer.py: add small number w0={w0}, h0={h0}')
    sqrt_N = math.sqrt(N)
    sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
        scale_factor=(sx, sy),
        mode="bicubic",
        antialias=self.interpolate_antialias,
    )
    print(f'DEBUG dinov2 vision_transformer.py: patch_pos_embed.shape={patch_pos_embed.shape}')
    assert int(w0) == patch_pos_embed.shape[-2]
    assert int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
```
I checked the output by feeding in a rectangular image and found that the small addition did not change w0 and h0. See the output:
```
image_crop.shape=torch.Size([1, 3, 434, 546])
DEBUG dinov2 vision_transformer.py: w0=31, h0=39
DEBUG dinov2 vision_transformer.py: add small number w0=31.0, h0=39.0
DEBUG dinov2 vision_transformer.py: patch_pos_embed.shape=torch.Size([1, 384, 31, 38])
```
Note that in `__init__`, `interpolate_offset=0.1`.
Here is the error I got:

```
File "home/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 204, in interpolate_pos_encoding
    assert int(w0) == patch_pos_embed.shape[-2]
AssertionError
```
Note: I used the pretrained dinov2_vits14_reg model.
Could it be that those pretrained checkpoints have no `interpolate_offset` field, or that it is set to 0?
In the meantime, I will manually set it to 0.1 and see if that helps.
Update: manually adding 0.1 instead of `self.interpolate_offset` solved the issue. This means the pretrained checkpoint has it set to 0.0, I assume.
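For anyone curious why 0.1 matters: the failure can be reproduced with plain Python arithmetic. A minimal sketch, assuming the pretrained ViT-S/14 positional grid is 37x37 (N = 1369), which is consistent with the w0/h0 values in the log above, and assuming `interpolate()` effectively computes the output size as floor(input_size * scale_factor):

```python
import math

sqrt_N = 37                      # assumed 37x37 pos_embed grid (N = 1369)
w0, h0 = 434 // 14, 546 // 14    # 31 and 39 patches, as in the log above

# With interpolate_offset = 0, the scale factors are plain ratios, and
# the round trip sqrt_N * (h0 / sqrt_N) can land just below the integer.
sx, sy = w0 / sqrt_N, h0 / sqrt_N
print(sqrt_N * sx, math.floor(sqrt_N * sx))  # 31.0 -> 31 (ok)
print(sqrt_N * sy, math.floor(sqrt_N * sy))  # 38.99999999999999 -> 38 (not 39!)

# A small positive offset pushes the product safely above the integer
# boundary, which is why hard-coding 0.1 makes the asserts pass.
sy_off = (h0 + 0.1) / sqrt_N
print(math.floor(sqrt_N * sy_off))           # 39
```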
Is there a reason the interpolate call doesn't set the output size directly to (w0, h0) using the `size` parameter, rather than using the `scale_factor` parameter?
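Not the merged fix, just a self-contained sketch of that idea with a random stand-in tensor: passing `size=` requests the exact target grid, so no flooring of input_size * scale_factor can occur.

```python
import torch
import torch.nn.functional as F

dim, sqrt_N = 384, 37
w0, h0 = 31, 39                              # target patch grid from the log above

pos = torch.randn(1, sqrt_N * sqrt_N, dim)   # stand-in for patch_pos_embed
out = F.interpolate(
    pos.reshape(1, sqrt_N, sqrt_N, dim).permute(0, 3, 1, 2),
    size=(w0, h0),                           # exact output grid, no offset trick needed
    mode="bicubic",
)
print(out.shape)                             # torch.Size([1, 384, 31, 39]) -- both asserts pass
```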
> Update: manually adding 0.1 instead of `self.interpolate_offset` solved the issue. This means the pretrained checkpoint has it set to 0.0, I assume.
I'm facing the same issue. It only fails for certain widths/heights: 546, 602, 1092, ...
This should hopefully be fixed now with #378.
Hi @patricklabatut, I am trying to train on a rectangular image and I get the following assertion error even after the merged PR.
https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L192
```
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/zsuri/miniconda3/envs/dinov2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 325, in forward
    ret = self.forward_features(*args, **kwargs)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 258, in forward_features
    x = self.prepare_tokens_with_masks(x, masks)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 220, in prepare_tokens_with_masks
    x = x + self.interpolate_pos_encoding(x, w, h)
  File "/home/zsuri/dinov2/./dinov2/models/vision_transformer.py", line 192, in interpolate_pos_encoding
    assert N == M * M
AssertionError
```
It can be replicated as follows with the default vitl14.yaml config:

```python
model = SSLMetaArch(setup(args))
model.teacher.backbone(torch.zeros((1, 3, 1078, 1918)))
```
Also, I don't understand the intuition behind these lines https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L185-L211 for rectangular images. I tested with returning https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L184 even when h != w, and it seems to run. I'm not sure whether that produces any unwanted outcomes, though.
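One caveat with that shortcut: when npatch == N but the grid is rectangular, the forward pass runs, yet the flattened square-grid embeddings get reassigned to different (row, col) positions. A toy illustration with made-up grid sizes (36 tokens viewed as 6x6 vs. 4x9):

```python
import numpy as np

# Each "embedding" in a 6x6 square grid just records its (row, col).
square = np.array([[(r, c) for c in range(6)] for r in range(6)])
flat = square.reshape(36, 2)     # flattened, as the ViT consumes it

# Reinterpret the same 36 tokens as a rectangular 4x9 patch grid.
rect = flat.reshape(4, 9, 2)

# The patch at rectangular position (0, 8) now carries the embedding
# trained for square position (1, 2): same token count, wrong geometry.
print(rect[0, 8])                # [1 2]
```

So nothing crashes, but the positional signal no longer matches the spatial layout of the patches, which is presumably the unwanted outcome to watch for.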