ViT-pytorch
Why is the kernel normalized in StdConv2d?
I noticed that you used:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        # Standardize the kernel per output channel, over the
        # in-channel and spatial dimensions.
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
```
Why is `w` normalized here? Was there any special consideration behind implementing it this way? Thanks!
For CNNs, weight standardization is recommended in *Big Transfer (BiT): General Visual Representation Learning*; see Section 4.3 of the paper.
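For context, BiT pairs Weight Standardization with GroupNorm instead of BatchNorm, so the convolution's effective kernel always has (approximately) zero mean and unit variance per output channel regardless of batch statistics. Below is a minimal sketch of that pairing, assuming the `StdConv2d` class quoted above is in scope; the block layout and layer sizes here are illustrative, not taken from this repo:

```python
import torch
import torch.nn as nn

# A conv block in the GroupNorm + Weight Standardization style used by BiT.
block = nn.Sequential(
    StdConv2d(3, 64, kernel_size=3, padding=1, bias=False),
    nn.GroupNorm(num_groups=32, num_channels=64),
    nn.ReLU(inplace=True),
)

x = torch.randn(2, 3, 32, 32)
y = block(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])

# Verify what the forward pass actually convolves with: the standardized
# kernel has ~zero mean and ~unit variance per output channel.
w = block[0].weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (w - m) / torch.sqrt(v + 1e-5)
print(w_std.mean(dim=[1, 2, 3]).abs().max())            # close to 0
print(w_std.var(dim=[1, 2, 3], unbiased=False).mean())  # close to 1
```

Note that the raw `self.weight` parameter is never modified; the standardization is recomputed on the fly in every forward pass, so gradients flow through the normalization itself.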