How to use PNASNet5 as encoder in Unet in pytorch
I want use PNASNet5Large as encoder for my Unet here is my wrong aproach for the PNASNet5Large but working for resnet:
class UNetResNet(nn.Module):
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
pretrained=False, is_deconv=False):
super().__init__()
self.num_classes = num_classes
self.dropout_2d = dropout_2d
if encoder_depth == 34:
self.encoder = torchvision.models.resnet34(pretrained=pretrained)
bottom_channel_nr = 512
elif encoder_depth == 101:
self.encoder = torchvision.models.resnet101(pretrained=pretrained)
bottom_channel_nr = 2048
elif encoder_depth == 152: #this works
self.encoder = torchvision.models.resnet152(pretrained=pretrained)
bottom_channel_nr = 2048
elif encoder_depth == 777: #coded version for the pnasnet
self.encoder = PNASNet5Large()
bottom_channel_nr = 4320 #this unknown for me as well
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
self.encoder.bn1,
self.encoder.relu,
self.pool)
self.conv2 = self.encoder.layer1 #PNASNet5Large doesn't have such layers
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4
self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)
self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
is_deconv)
self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec0 = ConvRelu(num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
conv4 = self.conv4(conv3)
conv5 = self.conv5(conv4)
center = self.center(conv5)
dec5 = self.dec5(torch.cat([center, conv5], 1))
dec4 = self.dec4(torch.cat([dec5, conv4], 1))
dec3 = self.dec3(torch.cat([dec4, conv3], 1))
dec2 = self.dec2(torch.cat([dec3, conv2], 1))
dec1 = self.dec1(dec2)
dec0 = self.dec0(dec1)
return self.final(F.dropout2d(dec0, p=self.dropout_2d))
- How to get how many bottom channels pnasnet has. It ends up following way:
...
self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
in_channels_right=4320, out_channels_right=864)
self.relu = nn.ReLU()
self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
self.dropout = nn.Dropout(0.5)
self.last_linear = nn.Linear(4320, num_classes)
Is 4320 the answer or not, in_channels_left and out_channels_left - something new for me
- Resnet has somekind of 4 big layers which I use and encoders in my Unet arch, how get similar layer from pnasnet
I'm using pytorch 3.1 and this is the link to the [Pnasnet directory][1]
- AttributeError: 'PNASNet5Large' object has no attribute 'conv1' - so doesn't have conv1 as well
You must define new forward methods for PNASNet such as layer1, layer2, layer2, etc. So that when you call self.conv1(x), it calls the layer1 method from PNASNet.
Good luck
class UNetPNASNet(nn.Module):
def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
pretrained=False, is_deconv=False):
super().__init__()
self.num_classes = num_classes
self.dropout_2d = dropout_2d
self.encoder = PNASNet5Large()
bottom_channel_nr = 4320
self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 *2, num_filters * 8, False)
self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
self.dec2 = DecoderBlockV2(num_filters * 4 * 4, num_filters * 4 * 4, num_filters, is_deconv)
self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
self.dec0 = ConvRelu(num_filters, num_filters)
self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)
def forward(self, x):
features = self.encoder.features(x)
relued_features = self.encoder.relu(features)
avg_pooled_features = self.encoder.avg_pool(relued_features)
center = self.center(avg_pooled_features)
dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
dec3 = self.dec3(torch.cat([dec4, features], 1))
dec2 = self.dec2(dec3)
dec1 = self.dec1(dec2)
dec0 = self.dec0(dec1)
return self.final(F.dropout2d(dec0, p=self.dropout_2d))
Tried this but failed, input image size 128*128
RuntimeError: Given input size: (4320x4x4). Calculated output size: (4320x-6x-6). Output size is too small at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63