
Why is the kernel normalized in StdConv2d?


I noticed that you used

import torch
import torch.nn as nn
import torch.nn.functional as F

class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        # standardize each output filter: zero mean, unit variance over (C_in, kH, kW)
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

Why is 'w' normalized here? Is there any special consideration behind implementing it this way? Thanks.

xychenunc commented on Mar 09, 2021

In the CNN backbone, weight standardization is adopted from Big Transfer (BiT): General Visual Representation Learning, which replaces Batch Normalization with Group Normalization plus Weight Standardization so the model trains well with small per-device batch sizes and transfers better. See Section 4.3 of the paper.
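
As a rough illustration (not part of the repository code), the sketch below assumes the StdConv2d class quoted above is in scope and checks that the weights actually used in the forward pass have roughly zero mean and unit variance per output filter; the layer sizes here are arbitrary:

import torch

# StdConv2d is a drop-in replacement for nn.Conv2d (illustrative sizes)
conv = StdConv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 32, 32)
y = conv(x)          # same call signature as nn.Conv2d
print(y.shape)       # torch.Size([1, 8, 32, 32])

# reproduce the standardization applied inside forward()
w = conv.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (w - m) / torch.sqrt(v + 1e-5)
print(w_std.mean(dim=[1, 2, 3]))                 # ~0 for every output filter
print(w_std.var(dim=[1, 2, 3], unbiased=False))  # ~1 for every output filter

Note that only the weights used in the convolution are standardized; the stored parameter self.weight is left unchanged.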

jeonsworld commented on Apr 16, 2021