Person_reID_baseline_pytorch
Class activation heat map.
Is there a way to compute and visualize class activation heatmaps on the query image (or on the resulting re-identified image), so that we can see which parts of the image the model focused on most to produce the final result?
This is my code. I used a different model loader. You need to modify the model loader part to use it.
##################################
# Visualize HeatMap by sum
# Zheng, Zhedong, Liang Zheng, and Yi Yang. "A discriminatively learned cnn embedding for person reidentification." ACM Transactions on Multimedia Computing, Communications, and Applications (TOMM) 14, no. 1 (2018): 13.
###################################
import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
from model import ft_net, ft_net_dense, ft_net_NAS, PCB, PCB_test
from utils import load_network
import yaml
import argparse
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--data_dir',default='../Market/pytorch',type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=1, type=int, help='batchsize')
opt = parser.parse_args()
config_path = os.path.join('./model',opt.name,'opts.yaml')
with open(config_path, 'r') as stream:
    # PyYAML >= 5.1 expects an explicit Loader; plain yaml.load(stream) fails on PyYAML 6
    config = yaml.load(stream, Loader=yaml.FullLoader)
opt.fp16 = config['fp16']
opt.PCB = config['PCB']
opt.use_dense = config['use_dense']
opt.use_NAS = config['use_NAS']
opt.stride = config['stride']
if 'h' in config:
    opt.h = config['h']
    opt.w = config['w']
if 'nclasses' in config: # to be compatible with old config files
    opt.nclasses = config['nclasses']
else:
    opt.nclasses = 751
def heatmap2d(img, arr):
    # Show the input image and its heatmap side by side
    fig = plt.figure()
    ax0 = fig.add_subplot(121, title="Image")
    ax1 = fig.add_subplot(122, title="Heatmap")
    ax0.imshow(Image.open(img))
    heatmap = ax1.imshow(arr, cmap='viridis')
    fig.colorbar(heatmap)
    #plt.show()
    fig.savefig('heatmap')
data_transforms = transforms.Compose([
    transforms.Resize((opt.h, opt.w), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_datasets = {x: datasets.ImageFolder(os.path.join(opt.data_dir, x), data_transforms) for x in ['train']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                              shuffle=False, num_workers=1) for x in ['train']}
imgpath = image_datasets['train'].imgs
model, _, epoch = load_network(opt.name, opt)
model.classifier.classifier = nn.Sequential()
model = model.eval().cuda()
data = next(iter(dataloaders['train']))
img, label = data
with torch.no_grad():
    # Run the backbone step by step and stop at the layer to visualize
    x = model.model.conv1(img.cuda())
    x = model.model.bn1(x)
    x = model.model.relu(x)
    x = model.model.maxpool(x)
    x = model.model.layer1(x)
    x = model.model.layer2(x)
    output = model.model.layer3(x)
    #output = model.model.layer4(x)
print(output.shape)
heatmap = output.squeeze().sum(dim=0).cpu().numpy()
print(heatmap.shape)
#test_array = np.arange(100 * 100).reshape(100, 100)
# Result is saved as `heatmap.png`
heatmap2d(imgpath[0][0],heatmap)
https://github.com/layumi/Person-reID-verification/issues/4
Thanks @layumi, it worked for me. I am able to generate the heatmaps.
Please find below a snapshot of the heatmap that I generated for my dataset:
But what if I need the heatmaps plotted on the original image, as shown in this image taken from one of your projects? I think those heatmaps look more realistic.
@Rajat-Mehta One simple way is 0.5 * original image + 0.5 * heatmap, and then imshow the combined result.
@layumi What if the dimensions of the original image and the heatmap are not the same? In my case, the original image is 256x128 and the heatmap is 16x8. The addition you suggested won't work in that case.
Please resize the heatmap to match the size of the original image.
@layumi How do I combine the results, e.g. 0.5 * original image + 0.5 * heatmap, after resizing the heatmap to the same size as the original image?
@lynnw123 Just add them together and clip the values (if a value is > 255, reset it to 255). For example,
combined_result = np.uint8(0.5 * x + 0.5 * y)
The combined image did not look correct:
I made the following changes:
heatmap = np.resize(heatmap, (128,64))
Inside heatmap2d func:
img = np.resize(Image.open(img), (128,64))
combined = np.uint8(0.5*img + 0.5 * arr)
heatmap = ax1.imshow(combined)
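One likely reason the combined image above looks wrong: `np.resize` does not interpolate. It only repeats or truncates the flattened array to fill the new shape, so both the image and the heatmap get scrambled. Below is a minimal sketch of an overlay that resizes with interpolation instead. The helper name `overlay_heatmap`, the `alpha` parameter, and the `jet` colormap are my own choices for illustration and are not part of the repo.

```python
import numpy as np
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from PIL import Image


def overlay_heatmap(image, heatmap, alpha=0.5):
    """Blend a low-resolution heatmap onto a PIL image (hypothetical helper).

    image:   PIL.Image, any mode (converted to RGB)
    heatmap: 2D numpy array, e.g. 16x8 from layer4
    returns: HxWx3 uint8 array
    """
    rgb = np.asarray(image.convert('RGB'), dtype=np.float32)
    h, w = rgb.shape[:2]

    # Scale the heatmap to [0, 1] so the colormap uses its full range.
    hm = heatmap.astype(np.float32)
    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-12)

    # Bilinear upsampling; PIL's resize takes (width, height).
    hm = Image.fromarray(hm).resize((w, h), resample=Image.BILINEAR)
    hm = np.clip(np.asarray(hm), 0.0, 1.0)

    # Map to RGB colours in 0..255 and drop the alpha channel.
    colored = plt.get_cmap('jet')(hm)[..., :3] * 255.0

    # Weighted sum, clipped to the valid pixel range.
    combined = alpha * rgb + (1.0 - alpha) * colored
    return np.uint8(np.clip(combined, 0, 255))
```

With the script above, something like `ax1.imshow(overlay_heatmap(Image.open(img), arr))` inside `heatmap2d` should show the heatmap on top of the image at its original resolution.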
@lynnw123 Did you solve this issue? I mean, are you able to display the heatmap on top of the image and see which regions were activated in relation to the image?
@layumi I was wondering if I can display the 2048x16x8 activation from layer4 on top of the 256x128x3 input image. When I display the activation as 16x8, the heatmap is not clear (as shown below): it looks stretched out, and the activated regions are hard to make out.
After we have got the overall heatmap, how can we know which part each local branch is focusing at?
Hi @milliema
You may modify the code by replacing the `sum` with the index to visualize the heatmap of any specific layer.
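To make the suggestion concrete, here is a minimal sketch. I read "the index" as indexing along the dimension that the script sums over (dim 0 of `output.squeeze()`). A random NumPy array stands in for `output.squeeze().cpu().numpy()`; the shape 1024x24x8 and the index `k` are assumptions for illustration.

```python
import numpy as np

# Stand-in for output.squeeze().cpu().numpy(): a CxHxW feature map.
rng = np.random.default_rng(0)
feature_map = rng.random((1024, 24, 8), dtype=np.float32)

# Overall heatmap: sum over the channel dimension (what the script does).
heatmap_sum = feature_map.sum(axis=0)

# Single-slice heatmap: pick one index instead of summing.
k = 5  # hypothetical index along the summed dimension
heatmap_k = feature_map[k]

print(heatmap_sum.shape, heatmap_k.shape)
```

In the original script the equivalent change would be along the lines of `output.squeeze()[k].cpu().numpy()` in place of `output.squeeze().sum(dim=0).cpu().numpy()`.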
Thanks a lot for the quick reply. In my understanding, the feature map after the backbone (layer3/4) has size 24x8x1024 (HxWxC); by taking the sum along the C dimension, we get the overall 24x8 heatmap. As for the heatmap of a local branch: since the partition is applied to the backbone feature map along the H dimension, does this mean the top 4x8 part of the overall heatmap corresponds to the heatmap of the 1st local branch? Is that correct?
Hi @milliema
Yes. If you use `PCB`, you evenly split the 24x8 heatmap into 6 parts of 4x8.
Otherwise, if the part pooling uses 5 parts (24 is not divisible by 5), PyTorch may split the parts with overlaps.
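For reference, a small sketch of which heatmap rows each part would cover. It assumes the part pooling is an adaptive pooling whose bins run from floor(i*H/P) to ceil((i+1)*H/P); the function name `part_bounds` is made up for this example.

```python
def part_bounds(height, parts):
    """Row ranges [start, end) per part, using floor/ceil bin edges.

    Assumption: this mirrors how adaptive pooling picks its bins.
    """
    bounds = []
    for i in range(parts):
        start = (i * height) // parts                   # floor(i * H / P)
        end = ((i + 1) * height + parts - 1) // parts   # ceil((i + 1) * H / P)
        bounds.append((start, end))
    return bounds


print(part_bounds(24, 6))  # six disjoint stripes of 4 rows each
print(part_bounds(24, 5))  # five stripes; neighbours share one row
```

Under that assumption, `heatmap[start:end]` on the 24x8 overall heatmap gives the region seen by one part. With 6 parts the first branch covers rows 0 to 3; with 5 parts the first two ranges are (0, 5) and (4, 10), so row 4 belongs to both.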