EfficientNet-PyTorch

How to add an additional layer to a pre-trained model?

Open talhaanwarch opened this issue 4 years ago • 3 comments

Can you please guide me on how to add some extra fully connected layers on top of a pre-trained model?

from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')

I am confused about how to access the last layer and connect it to another layer.

talhaanwarch avatar Jul 14 '20 20:07 talhaanwarch

You can do something like

import torch.nn as nn

# classes = number of output classes for your task
model._fc = nn.Sequential(nn.Linear(model._fc.in_features, 512),
                          nn.ReLU(),
                          nn.Dropout(0.25),
                          nn.Linear(512, 128),
                          nn.ReLU(),
                          nn.Dropout(0.50),
                          nn.Linear(128, classes))

or if you want to make bigger changes:

class MyEfficientNet(nn.Module):

    def __init__(self, classes):
        super().__init__()

        # EfficientNet
        self.network = EfficientNet.from_pretrained("efficientnet-b0")
        
        # Replace last layer
        self.network._fc = nn.Sequential(nn.Linear(self.network._fc.in_features, 512), 
                                         nn.ReLU(),  
                                         nn.Dropout(0.25),
                                         nn.Linear(512, 128), 
                                         nn.ReLU(),  
                                         nn.Dropout(0.50), 
                                         nn.Linear(128, classes))
    
    def forward(self, x):
        out = self.network(x)
        return out

model = MyEfficientNet(classes)
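
As a quick sanity check (just a sketch, assuming classes = 10 and standard 224x224 inputs), you can run a dummy batch through the model and confirm the output shape:

import torch

model.eval()                        # disable dropout for a deterministic check
x = torch.randn(2, 3, 224, 224)     # dummy batch of two RGB images
with torch.no_grad():
    out = model(x)
print(out.shape)                    # torch.Size([2, 10]) when classes = 10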

Look good?

lukemelas avatar Jul 14 '20 20:07 lukemelas

Just wondering, will the last layer still have a Swish activation? When I print out the model, that seems to be the case. If so, how do you remove that last layer?

The last few lines of the output of print(model):

(_bn1): BatchNorm2d(1280, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_avg_pooling): AdaptiveAvgPool2d(output_size=1)
    (_dropout): Dropout(p=0.2, inplace=False)
    (_fc): Sequential(
      (0): Linear(in_features=1280, out_features=512, bias=True)
      (1): ReLU()
      (2): Dropout(p=0.25, inplace=False)
      (3): Linear(in_features=512, out_features=128, bias=True)
      (4): ReLU()
      (5): Dropout(p=0.25, inplace=False)
      (6): Linear(in_features=128, out_features=1, bias=True)
    )
    (_swish): MemoryEfficientSwish()
  )
)
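
Worth noting: print(model) lists submodules in the order they were registered, not necessarily the order they are applied in forward(). One way to check what actually runs after the new _fc (just a sketch, assuming model is the EfficientNet itself as in the first snippet; use model.network.forward if you wrapped it in a class) is to print the source of the forward pass:

import inspect

# shows whether _swish is actually applied after _fc during the forward pass
print(inspect.getsource(model.forward))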

sachinruk avatar Jul 16 '20 12:07 sachinruk

I've expanded on the question above in my SO question.

sachinruk avatar Jul 17 '20 13:07 sachinruk