Positional encoding in DINOv2: Case without interpolation
Hi,
I have been reviewing the DINOv2 code in Candle and noticed what is most likely a bug (unless I have misunderstood the code).
As far as I understand, the function interpolate_pos_encoding() is used to interpolate the transformer's absolute positional encoding before it is added to the tokens xs:
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
    let (_b, _nc, w, h) = xs.dims4()?;
    let xs = self.patch_embed.forward(xs)?;
    let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
    &xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
But the interpolation function seems buggy in the case where no interpolation is needed and the early-return "shortcut" is taken. See the beginning of the function:
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
    let npatch = xs.dim(1)? - 1;
    let n = self.pos_embed.dim(1)? - 1;
    let sqrt_n = (n as f64).sqrt();
    if npatch == n && w == h {
        return Ok(xs.clone());
    }
    let class_pos_embed = self.pos_embed.i((.., ..1))?;
    let patch_pos_embed = self.pos_embed.i((.., 1..))?;
    // ...
I believe the shortcut should return the positional embedding itself instead:
if npatch == n && w == h {
    return Ok(self.pos_embed.clone());
}
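(The .clone() is needed because pos_embed is borrowed through &self; as far as I know, cloning a candle Tensor is cheap since it only bumps a reference count on the underlying storage.)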
Otherwise, since the caller computes &xs + &self.interpolate_pos_encoding(&xs, w, h)?, we end up adding the tokens to themselves, i.e. doubling the values of xs instead of adding the positional encoding.
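To make the effect concrete, here is a minimal, self-contained sketch using the candle-core crate directly (not the model code; the shapes and values are made up purely for illustration). It shows that if the shortcut hands the tokens back to the caller, the addition collapses to xs + xs:

use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Stand-ins for the token tensor and the positional embedding.
    let xs = Tensor::ones((1, 5, 4), DType::F32, &dev)?;
    let pos_embed = Tensor::zeros((1, 5, 4), DType::F32, &dev)?;

    // What the current shortcut effectively makes the caller compute: xs + xs.
    let buggy = (&xs + &xs)?;
    // What the caller actually wants: xs + pos_embed.
    let fixed = (&xs + &pos_embed)?;

    println!("buggy first token: {}", buggy.i((0, 0))?); // every value doubled to 2
    println!("fixed first token: {}", fixed.i((0, 0))?); // values stay at 1
    Ok(())
}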
Thanks in advance for your answer,