pretrained-models.pytorch icon indicating copy to clipboard operation
pretrained-models.pytorch copied to clipboard

How to use PNASNet5 as encoder in Unet in pytorch

Open Diyago opened this issue 7 years ago • 2 comments

I want to use PNASNet5Large as the encoder for my Unet. Here is my (wrong) approach for PNASNet5Large — the same pattern works for resnet:

class UNetResNet(nn.Module):
    """U-Net segmentation model with a ResNet (or PNASNet) backbone as encoder.

    Args:
        encoder_depth: selects the backbone — 34/101/152 for torchvision
            ResNets, 777 for PNASNet5Large (experimental; PNASNet does not
            expose ``conv1``/``layer1..4``, so the wiring below fails for it).
        num_classes: number of output channels of the final 1x1 conv.
        num_filters: base channel width of the decoder blocks.
        dropout_2d: dropout probability applied before the final conv.
        pretrained: load ImageNet weights for the ResNet backbones.
        is_deconv: use transposed convs (vs. upsampling) in the decoder.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:  # this works
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 777:  # coded version for the pnasnet
            self.encoder = PNASNet5Large()
            # 4320 is the channel count of PNASNet's last cell (cell_11 /
            # last_linear input) — presumably the right "bottom" width, but
            # the layer wiring below still fails; see the thread discussion.
            bottom_channel_nr = 4320
        else:
            # Fail fast instead of leaving bottom_channel_nr undefined,
            # which would raise a confusing NameError further down.
            raise ValueError(
                "Unsupported encoder_depth: {} (expected 34, 101, 152 or 777)".format(encoder_depth))

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        # Stem: reuse the backbone's first conv/bn/relu, then downsample.
        # NOTE(review): PNASNet5Large has no conv1/bn1 attributes, so this
        # raises AttributeError for encoder_depth == 777.
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)

        # ResNet's four stages become the encoder taps for the skip links.
        self.conv2 = self.encoder.layer1  # PNASNet5Large doesn't have such layers
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        # Decoder: each block consumes the previous decoder output
        # concatenated with the matching encoder feature map.
        self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                 is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Encode with the backbone, decode with skip connections.

        Returns a (batch, num_classes, H', W') logit map.
        """
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        # Pass training=self.training: F.dropout2d defaults to training=True,
        # which would (incorrectly) keep dropping channels at eval time.
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))
  1. How to get how many bottom channels pnasnet has. It ends up following way:
...
   self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                              in_channels_right=4320, out_channels_right=864)
          self.relu = nn.ReLU()
          self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
          self.dropout = nn.Dropout(0.5)
          self.last_linear = nn.Linear(4320, num_classes)

Is 4320 the answer or not? Also, `in_channels_left` and `out_channels_left` are new concepts to me.

  1. Resnet has some kind of 4 big layers which I use as encoders in my Unet architecture — how do I get similar layers from pnasnet?

I'm using pytorch 3.1 and this is the link to the [Pnasnet directory][1]

  1. AttributeError: 'PNASNet5Large' object has no attribute 'conv1' - so doesn't have conv1 as well

Diyago avatar Sep 08 '18 13:09 Diyago

You must define new forward methods for PNASNet such as layer1, layer2, layer3, etc., so that when you call self.conv1(x), it calls the layer1 method from PNASNet.

Good luck

Cadene avatar Sep 11 '18 02:09 Cadene

class UNetPNASNet(nn.Module):
    """U-Net variant using PNASNet5Large's whole feature extractor as encoder.

    Args:
        encoder_depth: unused; kept for signature compatibility with
            the ResNet-based variant.
        num_classes: number of output channels of the final 1x1 conv.
        num_filters: base channel width of the decoder blocks.
        dropout_2d: dropout probability applied before the final conv.
        pretrained: unused; kept for signature compatibility.
        is_deconv: use transposed convs (vs. upsampling) in the decoder.
    """

    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.encoder = PNASNet5Large()
        # PNASNet's final cell outputs 4320 channels (it is also the
        # in_features of its last_linear classifier head).
        bottom_channel_nr = 4320
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)

        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(num_filters * 4 * 4, num_filters * 4 * 4, num_filters, is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the PNASNet trunk, then decode to a segmentation map.

        NOTE(review): self.encoder.avg_pool is AvgPool2d(11, stride=1) sized
        for 331x331 inputs; for a 128x128 input the feature map reaching it
        is only 4x4, which produces the "Output size is too small"
        RuntimeError quoted below — the pooling stage likely should be
        skipped or replaced for small inputs. TODO confirm.
        """
        features = self.encoder.features(x)
        relued_features = self.encoder.relu(features)
        avg_pooled_features = self.encoder.avg_pool(relued_features)
        center = self.center(avg_pooled_features)
        dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
        dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
        dec3 = self.dec3(torch.cat([dec4, features], 1))
        dec2 = self.dec2(dec3)
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        # training=self.training so dropout is disabled in eval mode
        # (F.dropout2d defaults to training=True otherwise).
        return self.final(F.dropout2d(dec0, p=self.dropout_2d, training=self.training))

Tried this but failed, input image size 128*128

RuntimeError: Given input size: (4320x4x4). Calculated output size: (4320x-6x-6). Output size is too small at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63

Diyago avatar Sep 11 '18 19:09 Diyago