
How to extract the features of an image from a trained PyTorch model?

Open daa233 opened this issue 6 years ago • 6 comments

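When computing a perceptual loss, you need to pass the images through a pretrained VGG16 or VGG19, extract their features from an intermediate layer, and compare them. In PyTorch, a pretrained model is quite likely defined with nn.Sequential, so its intermediate layers have no meaningful names (only indices). How can their outputs be extracted?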

How to extract the features of an image from a trained model in PyTorch? https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119

VGG16 Features

Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))             # conv1_1
    (1): ReLU(inplace)                                                                # relu1_1
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))            # conv1_2
    (3): ReLU(inplace)                                                                # relu1_2
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)   # pool1
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))           # conv2_1
    (6): ReLU(inplace)                                                                # relu2_1
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))          # conv2_2
    (8): ReLU(inplace)                                                                # relu2_2
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)   # pool2
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_1
    (11): ReLU(inplace)                                                               # relu3_1
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_2
    (13): ReLU(inplace)                                                               # relu3_2
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_3
    (15): ReLU(inplace)                                                               # relu3_3
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # pool3
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_1
    (18): ReLU(inplace)                                                               # relu4_1
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_2
    (20): ReLU(inplace)                                                               # relu4_2
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_3
    (22): ReLU(inplace)                                                               # relu4_3
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # pool4
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_1
    (25): ReLU(inplace)                                                               # relu5_1
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_2
    (27): ReLU(inplace)                                                               # relu5_2
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_3
    (29): ReLU(inplace)                                                               # relu5_3
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # pool5
)

VGG19 Features

Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))             # conv1_1
    (1): ReLU(inplace)                                                                # relu1_1
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))            # conv1_2
    (3): ReLU(inplace)                                                                # relu1_2
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)   # pool1
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))           # conv2_1
    (6): ReLU(inplace)                                                                # relu2_1
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))          # conv2_2
    (8): ReLU(inplace)                                                                # relu2_2
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)   # pool2
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_1
    (11): ReLU(inplace)                                                               # relu3_1
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_2
    (13): ReLU(inplace)                                                               # relu3_2
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_3
    (15): ReLU(inplace)                                                               # relu3_3
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv3_4
    (17): ReLU(inplace)                                                               # relu3_4
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # pool3
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_1
    (20): ReLU(inplace)                                                               # relu4_1
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_2
    (22): ReLU(inplace)                                                               # relu4_2
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_3
    (24): ReLU(inplace)                                                               # relu4_3
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv4_4
    (26): ReLU(inplace)                                                               # relu4_4
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # pool4
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_1
    (29): ReLU(inplace)                                                               # relu5_1
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_2
    (31): ReLU(inplace)                                                               # relu5_2
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_3
    (33): ReLU(inplace)                                                               # relu5_3
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))         # conv5_4
    (35): ReLU(inplace)                                                               # relu5_4
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # pool5
)
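Writing these names by hand is error-prone. As a sketch, assuming the features block follows the torchvision VGG pattern (Conv2d, an optional BatchNorm2d, ReLU, with a MaxPool2d closing each stage), the names can be generated by walking the Sequential. The helper name vgg_layer_names and the toy network are hypothetical, for illustration only.

```python
import torch.nn as nn


def vgg_layer_names(features: nn.Sequential) -> dict:
    """Map names such as conv1_1, relu1_1, pool1 to their index in `features`."""
    names = {}
    stage, conv = 1, 0  # current stage (between poolings) and conv counter inside it
    for index, module in enumerate(features):
        if isinstance(module, nn.Conv2d):
            conv += 1
            names[f"conv{stage}_{conv}"] = index
        elif isinstance(module, nn.BatchNorm2d):  # only present in the *_bn variants
            names[f"bn{stage}_{conv}"] = index
        elif isinstance(module, nn.ReLU):
            names[f"relu{stage}_{conv}"] = index
        elif isinstance(module, nn.MaxPool2d):
            names[f"pool{stage}"] = index
            stage, conv = stage + 1, 0  # the next conv starts a new stage


    return names


# Small stand-in with the layer pattern of the first two VGG stages
toy_features = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
)
print(vgg_layer_names(toy_features))
```

Applied to torchvision's models.vgg16().features, vgg_layer_names(...)["relu4_3"] should give 22, so features[:23] ends exactly at relu4_3.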

daa233 commented on Aug 06 '18

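Method 1. Redefine a network with the same structure, give the intermediate modules explicit names, then compute the outputs of the layers you need and return them.

For example: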

import torch.nn as nn
import torch.nn.functional as F


class Vgg16(nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X):
        h = F.relu(self.conv1_1(X), inplace=True)
        h = F.relu(self.conv1_2(h), inplace=True)
        # relu1_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv2_1(h), inplace=True)
        h = F.relu(self.conv2_2(h), inplace=True)
        # relu2_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv3_1(h), inplace=True)
        h = F.relu(self.conv3_2(h), inplace=True)
        h = F.relu(self.conv3_3(h), inplace=True)
        # relu3_3 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv4_1(h), inplace=True)
        h = F.relu(self.conv4_2(h), inplace=True)
        h = F.relu(self.conv4_3(h), inplace=True)
        # relu4_3 = h
        # pool4: the MUNIT code skips it; the standard VGG16 layout (index 23 above) has it
        h = F.max_pool2d(h, kernel_size=2, stride=2)

        h = F.relu(self.conv5_1(h), inplace=True)
        h = F.relu(self.conv5_2(h), inplace=True)
        h = F.relu(self.conv5_3(h), inplace=True)
        relu5_3 = h

        return relu5_3
        # to return several layers, uncomment the "# reluX_Y = h" lines above and use:
        # return [relu1_2, relu2_2, relu3_3, relu4_3]

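Reference: https://github.com/NVlabs/MUNIT/blob/master/networks.py#L393-L442

Then load the weights with a statement such as vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))). This only works if the keys in the saved file match the attribute names (conv1_1.weight, conv1_1.bias, ...).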
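If the weights come from torchvision's VGG16 rather than a standalone vgg16.weight file, the keys differ: features.state_dict() uses "0.weight", "2.weight", ..., while the class above expects "conv1_1.weight", "conv1_2.weight", .... A minimal sketch of renaming them, assuming the standard VGG16 conv indices from the listing in the question; remap_vgg16_keys is a hypothetical helper, not a torchvision function.

```python
# Index of each conv layer inside torchvision's vgg16().features (see the listing above)
VGG16_CONV_NAMES = {
    0: "conv1_1", 2: "conv1_2",
    5: "conv2_1", 7: "conv2_2",
    10: "conv3_1", 12: "conv3_2", 14: "conv3_3",
    17: "conv4_1", 19: "conv4_2", 21: "conv4_3",
    24: "conv5_1", 26: "conv5_2", 28: "conv5_3",
}


def remap_vgg16_keys(features_state):
    """Rename keys like '0.weight' to 'conv1_1.weight' (values are passed through unchanged)."""
    renamed = {}
    for key, value in features_state.items():
        index, param = key.split(".", 1)  # '0.weight' -> ('0', 'weight')
        renamed[f"{VGG16_CONV_NAMES[int(index)]}.{param}"] = value
    return renamed
```

Usage would look like vgg.load_state_dict(remap_vgg16_keys(models.vgg16(pretrained=True).features.state_dict())). The pooling and ReLU layers have no parameters, so the 13 conv layers are the only keys to rename.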

daa233 commented on Aug 06 '18

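Method 2. Without explicitly defining a new network, iterate over the network's modules, compute the outputs along the way, and return them at the end.

For example: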

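import torch.nn as nn


class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()  # required before assigning submodules
        self.submodule = submodule
        self.extracted_layers = extracted_layers  # e.g. ['3', '8'] for an nn.Sequential

    def forward(self, x):
        outputs = []
        # children run in definition order; for nn.Sequential the names are '0', '1', ...
        for name, module in self.submodule.named_children():
            x = module(x)
            if name in self.extracted_layers:
                # a following ReLU(inplace=True) would overwrite a stored conv output,
                # so extract ReLU indices (or store x.clone()) instead
                outputs += [x]
        return outputs + [x]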

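Reference: https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/13

The advantage of this approach is that the module stays stateless: it does not keep intermediate results as attributes, so they cannot be used by accident and do not keep memory occupied.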
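A different technique, not from the linked thread: PyTorch forward hooks (register_forward_hook) capture intermediate outputs without rewriting forward, and they also work for models that are not nn.Sequential. Unlike Method 2 this is not stateless, since the captured tensors live in a dict, so clear it and remove the hooks when done. A sketch; capture_outputs and the toy model are hypothetical.

```python
import torch
import torch.nn as nn


def capture_outputs(model: nn.Module, layer_names):
    """Register forward hooks on the named submodules; return (storage dict, hook handles)."""
    storage = {}
    handles = []
    modules = dict(model.named_modules())
    for name in layer_names:
        def hook(module, inputs, output, name=name):  # bind `name` at definition time
            # clone() so a following ReLU(inplace=True) cannot overwrite the stored tensor
            storage[name] = output.detach().clone()
        handles.append(modules[name].register_forward_hook(hook))
    return storage, handles


model = nn.Sequential(
    nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
)
captured, handles = capture_outputs(model, ["0", "1"])
with torch.no_grad():
    model(torch.randn(2, 3, 8, 8))
print({k: tuple(v.shape) for k, v in captured.items()})
for handle in handles:
    handle.remove()  # detach the hooks once they are no longer needed
```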

daa233 commented on Aug 06 '18

Method 3. (nn.Sequential modules only) Without defining a network, index the original model directly to get a sub-module.

For example:

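import torch
from PIL import Image
from torchvision import models, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'
img_path = 'example.jpg'  # placeholder: path to your input image

# vgg16_bn.features[:37] ends at relu5_1 (index 36); the model must be on the same device as the input
# (newer torchvision: models.vgg16_bn(weights=models.VGG16_BN_Weights.IMAGENET1K_V1))
model = models.vgg16_bn(pretrained=True).features[:37].to(device).eval()

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img = preprocess(Image.open(img_path).convert('RGB'))
img = img.unsqueeze(0).to(device)  # add the batch dimension: (1, 3, 224, 224)
with torch.no_grad():  # no gradients needed for pure feature extraction
    feature = model(img)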
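When several layers are needed in one pass (typical for a perceptual loss), the same indexing idea can cut features into consecutive slices, each slice feeding the next. A sketch, assuming plain VGG16 (not vgg16_bn), where relu1_2, relu2_2, relu3_3 and relu4_3 sit at indices 3, 8, 15 and 22 in the listing in the question; SlicedVGG is a hypothetical name.

```python
import torch
import torch.nn as nn


class SlicedVGG(nn.Module):
    """Run a Sequential in consecutive slices and return the output at each cut point."""

    def __init__(self, features: nn.Sequential, cut_indices):
        super().__init__()
        bounds = [0] + [i + 1 for i in cut_indices]  # slice ends are exclusive
        # slicing an nn.Sequential returns a new Sequential sharing the same layers;
        # layers after the last cut index are dropped
        self.slices = nn.ModuleList(
            features[start:end] for start, end in zip(bounds[:-1], bounds[1:])
        )

    def forward(self, x):
        outputs = []
        for block in self.slices:
            x = block(x)
            outputs.append(x)  # cut at ReLU indices so no inplace op overwrites these
        return outputs
```

For VGG16 this would be SlicedVGG(models.vgg16(pretrained=True).features, [3, 8, 15, 22]), returning [relu1_2, relu2_2, relu3_3, relu4_3] from a single forward pass.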

daa233 commented on Nov 03 '18


Method 2 only gives the output for one batch at a time. When the next batch is loaded, isn't the previous data lost?

Hiker01 commented on Apr 27 '19

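@Hiker01 All the methods given here process one batch of data (or a single image). To handle multiple batches, write the loop outside. For example, given a dataloader some_dataloader that supplies the data and a feature extractor feature_extractor, a structure like this collects the features of every batch in feature_list: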

feature_list = []
with torch.no_grad():  # otherwise every stored feature keeps its autograd graph alive
    for images in some_dataloader:
        feature_list += [feature_extractor(images)]
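A runnable version of the loop above, with a toy TensorDataset standing in for some_dataloader and a toy network for feature_extractor (both assumptions). It moves each batch's result to the CPU and concatenates along the batch dimension at the end, so the result looks as if the whole dataset had been one batch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins: 10 random "images" and a tiny conv feature extractor
dataset = TensorDataset(torch.randn(10, 3, 16, 16))
some_dataloader = DataLoader(dataset, batch_size=4)  # batches of 4, 4, 2
feature_extractor = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()

feature_list = []
with torch.no_grad():
    for (images,) in some_dataloader:  # a one-tensor TensorDataset yields 1-element batches
        feature_list.append(feature_extractor(images).cpu())

features = torch.cat(feature_list, dim=0)  # stack all batches: (10, 8, 16, 16)
print(features.shape)
```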

daa233 commented on Apr 28 '19

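@daa233 @Hiker01 With multiple batches on a single GPU, you can use the "loop outside" approach. If you have several GPUs, you can use nn.DataParallel: it replicates the network on each GPU, splits every input batch across them, and gathers the outputs, so a larger batch can be forwarded at once.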
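A sketch of the multi-GPU variant, assuming the model fits on each GPU; the toy network is a placeholder for a real feature extractor. nn.DataParallel splits the input along dimension 0, runs one replica per visible GPU, and gathers the outputs on the first device. Here it is only applied when more than one GPU is present, so the same code also runs on a single GPU or the CPU.

```python
import torch
import torch.nn as nn

# Hypothetical feature extractor; replace with e.g. a sliced VGG
feature_extractor = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.device_count() > 1:
    # replicate across all visible GPUs; each replica gets a chunk of the batch
    feature_extractor = nn.DataParallel(feature_extractor)
feature_extractor = feature_extractor.to(device)  # parameters must live on the first GPU

images = torch.randn(8, 3, 16, 16, device=device)
with torch.no_grad():
    features = feature_extractor(images)  # outputs gathered on `device`
print(features.shape)
```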

whubaichuan commented on Mar 23 '20