EfficientNet-PyTorch

How to change BN to GN and FRN layers in EfficientNet?

Open • mobassir94 opened this issue 5 years ago • 1 comment

I was trying GroupNorm in EfficientNet. My model code is:

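import torch
import torch.nn as nn
from efficientnet_pytorch import model as enet  # assumed import: module defining EfficientNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

out_dim = 5
enet_type = 'efficientnet-b0'

pretrained_model = {
    'efficientnet-b0': '../input/efficientnet-pytorch/efficientnet-b0-08094119.pth'
}


class enetv2(nn.Module):
    def __init__(self, backbone, out_dim):
        super(enetv2, self).__init__()
        # Build the backbone and load the local pretrained weights
        self.enet = enet.EfficientNet.from_name(backbone)
        self.enet.load_state_dict(torch.load(pretrained_model[backbone]))

        # Replace the original classifier with a new head of size out_dim
        self.myfc = nn.Linear(self.enet._fc.in_features, out_dim)
        self.enet._fc = nn.Identity()

    def extract(self, x):
        return self.enet(x)

    def forward(self, x):
        x = self.extract(x)
        x = self.myfc(x)
        return x


model = enetv2(enet_type, out_dim=out_dim)
model = model.to(device)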

If I inspect model.enet._bn0, I get this output: BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)

But with this code:

for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        # Get current bn layer
        bn = getattr(model, name)
        # Create new gn layer
        gn = nn.GroupNorm(1, bn.num_features)
        # Assign gn
        print('Swapping {} with {}'.format(bn, gn))
        setattr(model, name, gn)

print(model)

I get this error:

AttributeError                            Traceback (most recent call last)
in
      2     if isinstance(module, nn.BatchNorm2d):
      3         # Get current bn layer
----> 4         bn = getattr(model, name)
      5         # Create new gn layer
      6         gn = nn.GroupNorm(1, bn.num_features)

/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    592                 return modules[name]
    593         raise AttributeError("'{}' object has no attribute '{}'".format(
--> 594             type(self).__name__, name))
    595 
    596     def __setattr__(self, name, value):

AttributeError: 'enetv2' object has no attribute 'enet._bn0'

So how do I replace the BN layers with GN layers in my model?

I would also like to know how to replace those BN layers with Filter Response Normalization (FRN) layers.

mobassir94 • Jun 27 '20
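The traceback points at the lookup itself: named_modules() yields dotted paths such as enet._bn0 or enet._blocks.0._bn1, while getattr(model, name) and setattr(model, name, ...) only resolve a single attribute name, so 'enet._bn0' is not found on the enetv2 wrapper. Below is a minimal sketch of one way to fix the question's loop, assuming PyTorch 1.9 or newer (for nn.Module.get_submodule): split each path into its parent module and final attribute, then swap the layer on the parent. The helper replace_bn_with_gn and the toy Inner/Wrapper modules are hypothetical stand-ins for illustration, not part of EfficientNet-PyTorch.

```python
import torch
import torch.nn as nn


def replace_bn_with_gn(model: nn.Module, num_groups: int = 1) -> nn.Module:
    # Hypothetical helper. Collect the names first so the module tree is not
    # modified while named_modules() is still walking it.
    bn_names = [name for name, m in model.named_modules() if isinstance(m, nn.BatchNorm2d)]
    for name in bn_names:
        # "enet._bn0" -> parent path "enet", attribute "_bn0";
        # a top-level name has an empty parent path.
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        bn = getattr(parent, attr)
        setattr(parent, attr, nn.GroupNorm(num_groups, bn.num_features))
    return model


# Toy stand-ins that mimic the enetv2 -> enet nesting (hypothetical, not EfficientNet).
class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self._conv_stem = nn.Conv2d(3, 32, 3, padding=1, bias=False)
        self._bn0 = nn.BatchNorm2d(32)
        self._blocks = nn.ModuleList([nn.BatchNorm2d(32)])  # gives the path "enet._blocks.0"

    def forward(self, x):
        x = self._bn0(self._conv_stem(x))
        for block in self._blocks:
            x = block(x)
        return x


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.enet = Inner()

    def forward(self, x):
        return self.enet(x)


model = replace_bn_with_gn(Wrapper())
print(model.enet._bn0)
```

Applied to the enetv2 model from the question, the same loop should also reach the BN layers nested inside enet._blocks. The new GroupNorm layers start from default affine weights (ones and zeros), so the pretrained BN parameters are not carried over and the network needs fine-tuning afterwards.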

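import math
import torch.nn as nn

def convert_BN2GN(model):
    # Recursively replace every nn.BatchNorm2d in the model with nn.GroupNorm.
    for name, module in model._modules.items():
        if len(list(module.children())) > 0:
            # Container (e.g. a block or nn.ModuleList): recurse into it.
            convert_BN2GN(module)
        elif isinstance(module, nn.BatchNorm2d):
            # nn.GroupNorm requires num_features % num_groups == 0. EfficientNet has
            # layers with e.g. 24 or 40 channels, so use gcd(32, C) groups, not a fixed 32.
            num_groups = math.gcd(32, module.num_features)
            model._modules[name] = nn.GroupNorm(num_groups, module.num_features)

    return model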

Lg955 • Jul 29 '21
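The answer above handles GN only. For the FRN part of the question, torch.nn does not include a Filter Response Normalization layer, so one option is a small custom module. The sketch below follows the FRN formulation: per sample and channel, ν² = mean of x² over H and W, x̂ = x / sqrt(ν² + ε), then a thresholded linear unit (TLU) z = max(γx̂ + β, τ). The class name FRN, the helper convert_BN2FRN and eps=1e-6 are illustrative assumptions; the helper reuses the same recursive walk as convert_BN2GN.

```python
import torch
import torch.nn as nn


class FRN(nn.Module):
    """Filter Response Normalization followed by a thresholded linear unit (TLU)."""

    def __init__(self, num_features: int, eps: float = 1e-6):
        super().__init__()
        # Per-channel affine parameters and TLU threshold, shaped to broadcast over NCHW.
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.tau = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nu^2: mean squared activation over the spatial dims, per sample and channel.
        nu2 = x.pow(2).mean(dim=(2, 3), keepdim=True)
        x = x * torch.rsqrt(nu2 + self.eps)
        # Affine transform, then TLU: elementwise max with the learned threshold tau.
        return torch.max(self.gamma * x + self.beta, self.tau)


def convert_BN2FRN(model: nn.Module) -> nn.Module:
    # Hypothetical helper: same recursive walk as convert_BN2GN, swapping in FRN layers.
    for name, module in model._modules.items():
        if module is None:
            continue
        if len(list(module.children())) > 0:
            convert_BN2FRN(module)
        elif isinstance(module, nn.BatchNorm2d):
            model._modules[name] = FRN(module.num_features)
    return model


# Usage on a toy network (not EfficientNet itself).
net = convert_BN2FRN(nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.SiLU()))
y = net(torch.randn(2, 3, 8, 8))
```

The TLU already acts as an activation, while EfficientNet follows most BN layers with Swish but leaves the projection BN at the end of each block without one, so whether to keep the TLU on every swapped layer is a design choice this sketch does not settle.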