learning-notes
How to extract an image's features from a trained PyTorch model?
When computing a perceptual loss, the features of the images under comparison must be extracted from intermediate layers of a pretrained VGG16 or VGG19 network. In PyTorch, a pretrained model is very often defined with nn.Sequential,
so the intermediate layers have no meaningful names. How do we extract their outputs?
How to extract the features of an image from a trained model in PyTorch? https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119
VGG16 Features
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv1_1
(1): ReLU(inplace) # relu1_1
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv1_2
(3): ReLU(inplace) # relu1_2
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool1
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv2_1
(6): ReLU(inplace) # relu2_1
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv2_2
(8): ReLU(inplace) # relu2_2
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool2
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_1
(11): ReLU(inplace) # relu3_1
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_2
(13): ReLU(inplace) # relu3_2
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_3
(15): ReLU(inplace) # relu3_3
(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool3
(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_1
(18): ReLU(inplace) # relu4_1
(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_2
(20): ReLU(inplace) # relu4_2
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_3
(22): ReLU(inplace) # relu4_3
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool4
(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_1
(25): ReLU(inplace) # relu5_1
(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_2
(27): ReLU(inplace) # relu5_2
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_3
(29): ReLU(inplace) # relu5_3
(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool5
)
VGG19 Features
Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv1_1
(1): ReLU(inplace) # relu1_1
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv1_2
(3): ReLU(inplace) # relu1_2
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool1
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv2_1
(6): ReLU(inplace) # relu2_1
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv2_2
(8): ReLU(inplace) # relu2_2
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool2
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_1
(11): ReLU(inplace) # relu3_1
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_2
(13): ReLU(inplace) # relu3_2
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_3
(15): ReLU(inplace) # relu3_3
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv3_4
(17): ReLU(inplace) # relu3_4
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool3
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_1
(20): ReLU(inplace) # relu4_1
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_2
(22): ReLU(inplace) # relu4_2
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_3
(24): ReLU(inplace) # relu4_3
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv4_4
(26): ReLU(inplace) # relu4_4
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool4
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_1
(29): ReLU(inplace) # relu5_1
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_2
(31): ReLU(inplace) # relu5_2
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_3
(33): ReLU(inplace) # relu5_3
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # conv5_4
(35): ReLU(inplace) # relu5_4
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) # pool5
)
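The two listings above are simply what PyTorch prints for the pretrained models' features submodule (the # conv/relu/pool comments are added by hand); they can be reproduced with:

from torchvision import models

vgg16 = models.vgg16(pretrained=True)
print(vgg16.features)  # prints the VGG16 Sequential listing above

vgg19 = models.vgg19(pretrained=True)
print(vgg19.features)  # prints the VGG19 Sequential listing above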
Method 1. Redefine a network with the same structure, give the intermediate modules explicit names, then compute and return the outputs of the desired layers
For example:
import torch.nn as nn
import torch.nn.functional as F

class Vgg16(nn.Module):
    def __init__(self):
        super(Vgg16, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

    def forward(self, X):
        h = F.relu(self.conv1_1(X), inplace=True)
        h = F.relu(self.conv1_2(h), inplace=True)
        # relu1_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.relu(self.conv2_1(h), inplace=True)
        h = F.relu(self.conv2_2(h), inplace=True)
        # relu2_2 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.relu(self.conv3_1(h), inplace=True)
        h = F.relu(self.conv3_2(h), inplace=True)
        h = F.relu(self.conv3_3(h), inplace=True)
        # relu3_3 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)
        h = F.relu(self.conv4_1(h), inplace=True)
        h = F.relu(self.conv4_2(h), inplace=True)
        h = F.relu(self.conv4_3(h), inplace=True)
        # relu4_3 = h
        h = F.max_pool2d(h, kernel_size=2, stride=2)  # pool4: needed to match standard VGG16; the MUNIT reference omits it
        h = F.relu(self.conv5_1(h), inplace=True)
        h = F.relu(self.conv5_2(h), inplace=True)
        h = F.relu(self.conv5_3(h), inplace=True)
        relu5_3 = h
        return relu5_3
        # return [relu1_2, relu2_2, relu3_3, relu4_3]
Reference: https://github.com/NVlabs/MUNIT/blob/master/networks.py#L393-L442
Then load the pretrained weights with a statement like vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))).
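If you do not have a converted weight file like MUNIT's vgg16.weight, the weights can also be copied straight from torchvision's pretrained model. A minimal sketch, assuming the conv attributes above are declared in the same order as the Conv2d layers appear in vgg16.features (true for the class as written; verify on your own model):

import torch
from torchvision import models

vgg = Vgg16()
pretrained = models.vgg16(pretrained=True)

# vgg16.features contains only Conv2d/ReLU/MaxPool2d; only the Conv2d layers
# carry parameters, and they come out of .parameters() in declaration order,
# matching conv1_1 ... conv5_3 above.
for dst, src in zip(vgg.parameters(), pretrained.features.parameters()):
    dst.data.copy_(src.data)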
Method 2. Do not define a new network explicitly; instead, iterate over the model's submodules, collect the desired outputs, and return them
For example:
import torch.nn as nn

class FeatureExtractor(nn.Module):
    def __init__(self, submodule, extracted_layers):
        super(FeatureExtractor, self).__init__()
        self.submodule = submodule
        self.extracted_layers = extracted_layers

    def forward(self, x):
        outputs = []
        for name, module in self.submodule._modules.items():
            x = module(x)
            if name in self.extracted_layers:
                outputs += [x]
        return outputs + [x]  # the extracted intermediate features, plus the final output
Reference: https://discuss.pytorch.org/t/how-to-extract-features-of-an-image-from-a-trained-model/119/13
The advantage of this approach is that the module stays stateless: intermediate activations are never stored as attributes on the module, so they cannot be held onto inadvertently and keep occupying memory.
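A usage sketch with torchvision's VGG16. Note that for an nn.Sequential the keys of _modules are the stringified layer indices, so the "layer names" to match against are strings such as '3' (relu1_2), '8' (relu2_2), '15' (relu3_3) and '22' (relu4_3) from the VGG16 listing above:

import torch
from torchvision import models

vgg16 = models.vgg16(pretrained=True)
extractor = FeatureExtractor(vgg16.features, extracted_layers=['3', '8', '15', '22'])
extractor.eval()

x = torch.randn(1, 3, 224, 224)  # a dummy input batch
with torch.no_grad():
    features = extractor(x)      # [relu1_2, relu2_2, relu3_3, relu4_3, final output]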
Method 3. (nn.Sequential modules only) Do not define a new network at all; index the original model directly to obtain a submodule
For example:
from PIL import Image
import torch
from torchvision import models, transforms

model = models.vgg16_bn(pretrained=True).features[:37]  # keep only the layers up to index 36
model = model.to('cuda').eval()

img = transforms.Resize((224, 224))(Image.open(img_path).convert('RGB'))  # img_path: path to the input image
img = transforms.ToTensor()(img)
img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
img = img.unsqueeze(0)  # add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
img = img.to('cuda')
with torch.no_grad():
    feature = model(img)
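Tying this back to the perceptual loss mentioned at the beginning, a minimal sketch built on Method 3 (using relu4_3 of the non-BN VGG16 and an MSE on the features is one common choice, not the only one):

import torch
import torch.nn.functional as F
from torchvision import models

# features up to relu4_3 (index 22) in the VGG16 listing above
vgg_relu4_3 = models.vgg16(pretrained=True).features[:23].to('cuda').eval()
for p in vgg_relu4_3.parameters():
    p.requires_grad = False  # freeze VGG; only the generated image is optimized

def perceptual_loss(generated, target):
    return F.mse_loss(vgg_relu4_3(generated), vgg_relu4_3(target))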
Method 2 only yields the outputs for one batch at a time; when the next batch is loaded, isn't the previous batch's data lost?
@Hiker01 The methods given here all operate on a single batch of data (or a single image). To process multiple batches, write the loop outside. For example, with a dataloader providing the data and a feature extractor feature_extractor, the structure looks like the following, collecting every batch's features into feature_list:
feature_list = []
for images in some_dataloader:
    feature_list += [feature_extractor(images)]
@daa233 @Hiker01 For multiple batches on a single GPU, the "write the loop outside" approach works. If you have multiple GPUs, you can use nn.DataParallel to replicate the network onto each GPU and forward several batches' worth of data at once, getting the output for a larger effective batch size.
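A minimal sketch of that suggestion (the device_ids and batch sizes are illustrative; it assumes at least two visible GPUs):

import torch
import torch.nn as nn
from torchvision import models

model = models.vgg16(pretrained=True).features[:23]           # up to relu4_3
model = nn.DataParallel(model, device_ids=[0, 1]).to('cuda')  # replicate across two GPUs

big_batch = torch.randn(64, 3, 224, 224).to('cuda')  # e.g. two 32-image batches at once
features = model(big_batch)                          # split across GPUs and gathered automatically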