EfficientNet-PyTorch
GroupNorm and Weight Standardization
Hello,
I'm using EfficientNets as backbones for EfficientDets, and I'm running into issues with small batch sizes and BatchNorm layers due to the memory consumption of large models. To work around this, I've been reading about GroupNorm and Weight Standardization (WS). Do you think your Conv2dDynamicSamePadding class could be adapted to use this WS implementation:
import torch.nn as nn
import torch.nn.functional as F

class Conv2d(nn.Conv2d):
    """Conv2d with Weight Standardization applied to the kernel weights."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size,
                                     stride, padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        # Zero-center each output filter over its (in_channels, kH, kW) dims
        weight_mean = weight.mean(dim=1, keepdim=True).mean(
            dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # Divide each filter by its standard deviation (eps for stability)
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
Would it still work?

Renaud
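For context, here is a minimal self-contained sketch of what I have in mind: a weight-standardized conv paired with GroupNorm instead of BatchNorm, so normalization no longer depends on batch statistics. The class name `WSConv2d` and the group count are my own choices for illustration, not anything from this repo:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class WSConv2d(nn.Conv2d):
    """Hypothetical example: nn.Conv2d with Weight Standardization."""
    def forward(self, x):
        weight = self.weight
        # Standardize each output filter before convolving
        mean = weight.mean(dim=(1, 2, 3), keepdim=True)
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        return F.conv2d(x, (weight - mean) / std, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

# WS is intended to be used together with GroupNorm in place of BatchNorm
block = nn.Sequential(
    WSConv2d(3, 32, kernel_size=3, padding=1, bias=False),
    nn.GroupNorm(num_groups=8, num_channels=32),
    nn.ReLU(inplace=True),
)

x = torch.randn(2, 3, 16, 16)  # tiny batch is fine: no batch statistics used
y = block(x)
print(y.shape)  # torch.Size([2, 32, 16, 16])
```

Because GroupNorm normalizes over channel groups within each sample, the block above behaves identically at batch size 1 or 64, which is the main appeal for memory-constrained detector training.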