Deformable-DETR
There is a mistake in deformable_attn.py
At lines 215-219 of deformable_attn.py, the original code is:
```python
# B*M, H*W, C_v
feat = torch.einsum('nlds, nls -> nld', scale_features, A)
# B, H, W, C
feat = feat.view(nbatches, query_height, query_width, self.d_k * self.h)
```
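A quick way to see what the direct `view` does is to tag every element with its head index and check where the tags end up. This is a standalone sketch with made-up sizes. It assumes, as the proposed fix does, that the `B*M` axis is batch-major, i.e. row `b*M + m` holds head `m` of batch `b`.

```python
import torch

# Hypothetical sizes, for illustration only.
B, M, H, W, C_v = 2, 4, 3, 5, 8  # batch, heads, height, width, channels per head

# Stand-in for the einsum output, shape (B*M, H*W, C_v).
# Assumption: row b*M + m holds head m of batch b (batch-major).
# Each element stores its own head index so we can track where it lands.
head_id = torch.arange(M, dtype=torch.float32).view(1, M, 1, 1)
feat = head_id.expand(B, M, H * W, C_v).reshape(B * M, H * W, C_v)

# The original code: a direct view to (B, H, W, M*C_v).
out = feat.view(B, H, W, C_v * M)

# Which heads contribute to the channel vector of pixel (0, 0) of image 0?
print(out[0, 0, 0].unique())  # tensor([0.]) -> only head 0, none of heads 1..M-1
```

Pixel (0, 0) receives head 0's features for the first `M` spatial positions, rather than one `C_v` slice from each head.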
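I believe this is a mistake. The einsum output is laid out as (B*M, H*W, C_v), so viewing it directly as (B, H, W, M*C_v) only reinterprets memory without moving the head axis: each pixel's M*C_v channels are filled with C_v-sized chunks from M consecutive positions of the same head (or two adjacent heads at a boundary), instead of one chunk from each head. The head axis has to be moved next to the channel axis before the two are merged, so it should be: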
```python
# B*M, H*W, C_v
feat = torch.einsum('nlds, nls -> nld', scale_features, A)
# B, M, H*W, C_v
feat = feat.view(nbatches, self.h, query_height * query_width, self.d_k)
# B, H*W, M, C_v
feat = feat.permute(0, 2, 1, 3).contiguous()
# B, H, W, C
feat = feat.view(nbatches, query_height, query_width, self.d_k * self.h)
```
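The corrected reshaping can be checked in isolation. The sketch below wraps the three proposed lines in a hypothetical helper `merge_heads` (not part of deformable_attn.py) and verifies that channels `m*C_v:(m+1)*C_v` of every pixel hold exactly head `m`'s features. As above, the sizes are made up, and the check assumes the `B*M` axis is batch-major.

```python
import torch

def merge_heads(feat, nbatches, h, query_height, query_width, d_k):
    """Hypothetical helper: (B*M, H*W, C_v) -> (B, H, W, M*C_v), per the proposed fix."""
    feat = feat.view(nbatches, h, query_height * query_width, d_k)  # B, M, H*W, C_v
    feat = feat.permute(0, 2, 1, 3).contiguous()                    # B, H*W, M, C_v
    return feat.view(nbatches, query_height, query_width, d_k * h)  # B, H, W, M*C_v

# Hypothetical sizes, for illustration only.
B, M, H, W, C_v = 2, 4, 3, 5, 8
feat = torch.randn(B * M, H * W, C_v)
out = merge_heads(feat, B, M, H, W, C_v)

# Reference: head m of batch b is row b*M + m (batch-major assumption).
per_head = feat.view(B, M, H, W, C_v)
for m in range(M):
    # Channel slice m of every pixel must equal head m's features at that pixel.
    assert torch.equal(out[..., m * C_v:(m + 1) * C_v], per_head[:, m])
print("each head occupies its own C_v-wide channel slice")
```

An equivalent formulation is `feat.view(B, M, H, W, C_v).permute(0, 2, 3, 1, 4).reshape(B, H, W, M * C_v)`; the three-step version simply mirrors the proposed fix line by line.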
@ParadoxZW Hi, yes, it is a mistake. A PR is welcome.