
Network's performance decreases after adopting WS

Open NoOneUST opened this issue 4 years ago • 5 comments

I am training an instance segmentation network. Before adopting WS I could reach mAP 35.66 with Conv+GN, but after adopting WS I only reach 35.27. Is there something wrong with my code? My code to convert the original network to WS is below. Note that my original network uses a ResNet101-FPN backbone with deformable convs, depthwise-separable convs, and the linear bottlenecks introduced in MobileNet-V2.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv2d(nn.Conv2d):
    """nn.Conv2d with Weight Standardization applied to the kernel."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                     padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # per-output-filter mean over input channels and both spatial dims
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # per-output-filter std, with a small epsilon for numerical stability
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

def convertConv2WeightStand(module, nextChild=None):
    mod = module
    norm_list = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
                 nn.GroupNorm, nn.LayerNorm)
    # Replace only plain nn.Conv2d layers that are directly followed by a
    # normalization layer; the WS Conv2d defined above only handles the 2-D case.
    if type(mod) is nn.Conv2d and isinstance(nextChild, norm_list):
        mod = Conv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride,
                     mod.padding, mod.dilation, mod.groups, mod.bias is not None)
        # carry over the (possibly pre-trained) parameters of the original layer
        mod.weight.data.copy_(module.weight.data)
        if module.bias is not None:
            mod.bias.data.copy_(module.bias.data)

    moduleChildList = list(module.named_children())
    for index, (name, child) in enumerate(moduleChildList):
        nextChild = None
        if index < len(moduleChildList) - 1:
            nextChild = moduleChildList[index + 1][1]
        mod.add_module(name, convertConv2WeightStand(child, nextChild))

    return mod

if cfg.useWeightStandardization:
    net = convertConv2WeightStand(net)
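
For reference, a quick sanity check one could run after the conversion (assuming the Conv2d class and net from the snippets above) to confirm that only the expected convolutions were replaced:

# Count WS convs vs. remaining plain convs after conversion, to verify that
# only convs directly followed by a normalization layer were replaced.
ws_convs = [name for name, m in net.named_modules() if type(m) is Conv2d]
plain_convs = [name for name, m in net.named_modules() if type(m) is torch.nn.Conv2d]
print(f"WS convs: {len(ws_convs)}, plain convs: {len(plain_convs)}")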

NoOneUST avatar Mar 05 '20 06:03 NoOneUST

Thanks for the question. Did you also use the backbones pre-trained with WS? Also, make sure every WS-Conv2d is followed by an activation normalization layer; otherwise, use a regular Conv2d.
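
For illustration, a minimal sketch of that pattern (the block helpers below are hypothetical, not from this repo): a WS conv sits directly in front of a normalization layer, while a conv with no following norm stays a plain nn.Conv2d.

import torch.nn as nn

def ws_norm_block(in_ch, out_ch):
    # Conv2d here is the weight-standardized class from the snippet above;
    # assumes out_ch is divisible by the number of groups
    return nn.Sequential(
        Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(num_groups=32, num_channels=out_ch),
        nn.ReLU(inplace=True),
    )

def plain_block(in_ch, out_ch):
    # no normalization follows, so keep a regular convolution
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )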

joe-siyuan-qiao avatar Mar 06 '20 21:03 joe-siyuan-qiao

Thanks for your reply. I tried both:

  1. replacing conv+BN in the backbone+FPN with WS+BN and then fine-tuning;
  2. not replacing conv+BN in the backbone+FPN.

The other network components are replaced with WS. In both cases I saw a performance decrease. I have verified the network architecture: only convs directly followed by BN are replaced with WS. I still have a doubt, though: for combined convs like a LinearBottleNeck followed by BN, i.e. 3x3 + 1x1 + 3x3 + BN, should we replace only the last 3x3 conv with WS, or all three convs? In my code, I chose the former.
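
For concreteness, a minimal sketch of that former option (the module and names below are illustrative, not the poster's actual LinearBottleNeck): only the conv that directly precedes the BN uses the weight-standardized Conv2d.

import torch.nn as nn

class BottleneckWSLast(nn.Module):
    """Illustrative 3x3 + 1x1 + 3x3 + BN block in which only the last conv,
    the one directly followed by BN, uses the weight-standardized Conv2d."""
    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 3, padding=1)             # plain conv
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, 1)                       # plain conv
        self.conv3 = Conv2d(mid_ch, out_ch, 3, padding=1, bias=False)   # WS conv (class above)
        self.bn = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        return self.bn(self.conv3(self.conv2(self.conv1(x))))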

NoOneUST avatar Mar 07 '20 06:03 NoOneUST

Sorry, it's hard for me to see where the problem might be given the details you provided. However, one thing I would recommend trying is removing the division by std (the weight / std step), i.e. only centering the weights. This gives up the benefit of the std normalization but is more tolerant of different architecture designs. This strategy might also apply to the combined convolutions.
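
A minimal sketch of that variant, assuming the same Conv2d subclass shape as in the original post: the forward pass only subtracts each filter's mean and skips the division by std.

import torch.nn as nn
import torch.nn.functional as F

class WCConv2d(nn.Conv2d):
    """Weight centralization only (a sketch of the suggestion above):
    subtract each output filter's mean, but do not divide by the std."""
    def forward(self, x):
        weight = self.weight
        # mean over input channels and both spatial dims, one value per filter
        weight_mean = weight.mean(dim=(1, 2, 3), keepdim=True)
        weight = weight - weight_mean
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)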

joe-siyuan-qiao avatar Mar 09 '20 02:03 joe-siyuan-qiao

@joe-siyuan-qiao Why does WS have to be followed by a normalization layer? From what I understand, WS aims to preserve the statistics of the tensors, so shouldn't it be useful even for layers without normalization?

In other words, WS can pass the statistical similarities from the input channels to the output channels, all the way from the image space where RGB channels are properly normalized.

gautamsreekumar avatar Mar 29 '21 17:03 gautamsreekumar

Just for your reference, on my task, GN > GN+WC (weight centralization) >> GN+WS.

hiyyg avatar Nov 07 '21 05:11 hiyyg