WeightStandardization
Network's performance decreases after adopting WS
I am training an instance segmentation network. Before adopting WS, I achieve 35.66 mAP with Conv+GN; after adopting WS, I only achieve 35.27. Is there something wrong with my code? My code for converting the original network to WS is below. Note that my original network has a ResNet101-FPN backbone with deformable convolutions, depthwise-separable convolutions, and the linear bottlenecks introduced in MobileNetV2.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Conv2d):
    """Conv2d with Weight Standardization: each filter is standardized before use."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Zero-center each output filter over (in_channels/groups, kH, kW)
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # Scale each filter to unit std (eps guards against division by zero)
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


NORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.LayerNorm)


def convertConv2WeightStand(module, nextChild=None):
    mod = module
    # Only plain 2D convs can use the 4D-weight WS Conv2d above; Conv1d/3d and
    # ConvTranspose layers have differently shaped weights and are left as-is.
    # "Followed by" means the next registered child, not the next layer in forward().
    if (isinstance(module, nn.Conv2d) and not isinstance(module, Conv2d)
            and isinstance(nextChild, NORM_TYPES)):
        mod = Conv2d(module.in_channels, module.out_channels, module.kernel_size,
                     module.stride, module.padding, module.dilation,
                     module.groups, module.bias is not None)
        mod = mod.to(device=module.weight.device, dtype=module.weight.dtype)
        # Keep the existing (e.g. pre-trained) parameters instead of re-initializing
        with torch.no_grad():
            mod.weight.copy_(module.weight)
            if module.bias is not None:
                mod.bias.copy_(module.bias)
    moduleChildList = list(module.named_children())
    for index, (name, child) in enumerate(moduleChildList):
        nextChild = None
        if index < len(moduleChildList) - 1:
            nextChild = moduleChildList[index + 1][1]
        mod.add_module(name, convertConv2WeightStand(child, nextChild))
    return mod


# cfg and net come from the surrounding training script
if cfg.useWeightStandardization:
    net = convertConv2WeightStand(net)
```
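The three chained `.mean()` calls and the `view`-based `std` in `forward` can be collapsed into single reductions over dims (1, 2, 3), since every filter has the same number of elements along each axis. The sketch below, assuming a PyTorch version where `Tensor.mean` and `Tensor.std` accept a tuple of dims, checks that both forms give the same standardized weights; `standardize_chained` and `standardize_fused` are hypothetical helper names.

```python
import torch


def standardize_chained(weight, eps=1e-5):
    # Same computation as the forward() above: chained means, then std via view
    mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
    centered = weight - mean
    std = centered.view(centered.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + eps
    return centered / std


def standardize_fused(weight, eps=1e-5):
    # One reduction per statistic over (in_channels/groups, kH, kW) of each filter
    mean = weight.mean(dim=(1, 2, 3), keepdim=True)
    std = weight.std(dim=(1, 2, 3), keepdim=True) + eps
    return (weight - mean) / std


torch.manual_seed(0)
w = torch.randn(8, 4, 3, 3)  # (out_channels, in_channels, kH, kW)
a = standardize_chained(w)
b = standardize_fused(w)
print(torch.allclose(a, b, atol=1e-6))
```

Either way, each of the 8 filters ends up with mean close to 0 and standard deviation close to 1; the fused form is simply easier to read.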
Thanks for the question. Did you also use the backbones pre-trained with WS? Also, make sure every WS-Conv2d is followed by an activation normalization layer; otherwise, use a regular Conv2d.
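One way to check the "every WS-Conv2d is followed by a normalization layer" condition without relying on `named_children()` order (which reflects registration order in `__init__`, not execution order) is to trace the model's data flow with `torch.fx`. This is a sketch under assumptions: the model must be symbolically traceable (dynamic control flow or custom ops such as some deformable-conv implementations may not be), and it should be run on the original model before conversion, because fx traces into a custom `Conv2d` subclass instead of treating it as a single module. `convs_followed_by_norm` and the choice of `NORM_TYPES` are hypothetical.

```python
import torch.nn as nn
from torch.fx import symbolic_trace

# Layers treated as "activation normalization" for this check (assumption)
NORM_TYPES = (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)


def convs_followed_by_norm(model):
    """Map each nn.Conv2d (by qualified name) to True if every consumer of its
    output is a normalization layer, else False."""
    gm = symbolic_trace(model)
    result = {}
    for node in gm.graph.nodes:
        if node.op != "call_module":
            continue
        if not isinstance(gm.get_submodule(node.target), nn.Conv2d):
            continue
        users = list(node.users)  # nodes that consume this conv's output
        result[node.target] = bool(users) and all(
            user.op == "call_module"
            and isinstance(gm.get_submodule(user.target), NORM_TYPES)
            for user in users
        )
    return result


model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1, bias=False),
    nn.GroupNorm(4, 16),
    nn.ReLU(),
    nn.Conv2d(16, 8, 1),  # e.g. a prediction layer with no norm afterwards
)
print(convs_followed_by_norm(model))
```

Convs mapped to `True` are candidates for WS; the ones mapped to `False` would stay regular `Conv2d` under the advice above.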
Thanks for your reply. I tried both of the following:
- replacing conv+BN in the backbone and FPN with WS+BN, then fine-tuning;
- leaving conv+BN in the backbone and FPN unchanged.
The other network components are converted to WS. In both cases, performance decreases. I have verified the network's architecture: only convs directly followed by BN are replaced with WS. One question: for combined convs such as a LinearBottleneck followed by BN, i.e. 3x3 + 1x1 + 3x3 + BN, should only the last 3x3 conv be replaced with WS, or all three? My code does the former.
Sorry, it's hard for me to see where the problem might be from the details you provided. However, one thing I would recommend trying is removing weight /= std, i.e. only centering the weights. This gives up the benefit of dividing by the standard deviation but is more tolerant of different architecture designs. The same strategy might also apply to the combined convolutions.
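Following the suggestion to drop the division by std, here is a minimal sketch of a convolution that only centers its filters (weight centralization), using the same `nn.Conv2d` subclassing approach as the code in the question. The class name `CenteredConv2d` and the `use_std` flag are hypothetical, added so that centering-only and full WS can be compared in one ablation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CenteredConv2d(nn.Conv2d):
    """Conv2d whose filters are zero-centered per output channel; with
    use_std=True it also divides by the per-filter std (full WS)."""

    def __init__(self, *args, use_std=False, eps=1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_std = use_std
        self.eps = eps

    def forward(self, x):
        w = self.weight
        # Center each filter over (in_channels/groups, kH, kW)
        w = w - w.mean(dim=(1, 2, 3), keepdim=True)
        if self.use_std:
            # Full Weight Standardization: also scale each filter to unit std
            w = w / (w.std(dim=(1, 2, 3), keepdim=True) + self.eps)
        return F.conv2d(x, w, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


torch.manual_seed(0)
conv = CenteredConv2d(4, 8, 3, padding=1, bias=False)  # centering only
x = torch.randn(2, 4, 16, 16)
y = conv(x)
print(y.shape)
```

Because the raw weights stay as ordinary parameters, pre-trained `nn.Conv2d` weights can be copied in unchanged; only the forward pass differs.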
@joe-siyuan-qiao Why does WS have to be followed by a normalization layer? As I understand it, WS aims to preserve the statistics of the tensors, so shouldn't it be useful even for layers without normalization?
In other words, WS can pass the statistical similarities from the input channels to the output channels, all the way from the image space where RGB channels are properly normalized.
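One concrete piece of this reasoning can be checked directly: because each standardized filter sums to zero, any offset shared by all input channels is cancelled. In the sketch below (assuming no bias and no padding, so every output position sees a full window), a constant input gives an all-zero WS output regardless of the constant, while the same unstandardized weights do not. This only illustrates the centering step; it says nothing about how variance propagates or whether a following normalization layer is needed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(8, 4, 3, 3)
# Weight Standardization: per-filter zero mean and unit std
ws_weight = (weight - weight.mean(dim=(1, 2, 3), keepdim=True)) / (
    weight.std(dim=(1, 2, 3), keepdim=True) + 1e-5)

results = []
for offset in (0.5, 3.0, -10.0):
    # Every input channel holds the same constant (shared mean, zero variance)
    x = torch.full((1, 4, 10, 10), offset)
    y_ws = F.conv2d(x, ws_weight)  # no padding, no bias
    y_plain = F.conv2d(x, weight)
    results.append((offset, y_ws.abs().max().item(), y_plain.abs().max().item()))

for offset, ws_max, plain_max in results:
    print(f"offset={offset}: max|WS out|={ws_max:.2e}, max|plain out|={plain_max:.2e}")
```

The WS output is zero up to floating-point error for every offset, whereas the plain convolution's output scales with the offset.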
For reference, on my task: GN > GN+WC (weight centralization) >> GN+WS.