Person_reID_baseline_pytorch

Class activation heat map.

Open Rajat-Mehta opened this issue 5 years ago • 14 comments

Is there a way to compute and visualize class activation heatmaps on the query image (or on the resulting re-identified image), so that we can see which parts of the image the model focused on to produce the final result?

Rajat-Mehta avatar Jul 01 '19 14:07 Rajat-Mehta

This is my code. I used a different model loader. You need to modify the model loader part to use it.

##################################
# Visualize Heatmap by sum
# Zheng, Zhedong, Liang Zheng, and Yi Yang. "A discriminatively learned cnn embedding for person reidentification." ACM Transactions on Multimedia Computing, Communications, and Applications (TOMM) 14, no. 1 (2018): 13.
###################################

import os
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import numpy as np
from model import ft_net, ft_net_dense, ft_net_NAS, PCB, PCB_test
from utils import load_network
import yaml
import argparse
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, models, transforms
from PIL import Image

parser = argparse.ArgumentParser(description='Training')

parser.add_argument('--data_dir',default='../Market/pytorch',type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=1, type=int, help='batchsize')

opt = parser.parse_args()

config_path = os.path.join('./model',opt.name,'opts.yaml')
with open(config_path, 'r') as stream:
        config = yaml.load(stream, Loader=yaml.FullLoader)  # recent PyYAML versions require an explicit Loader
opt.fp16 = config['fp16']
opt.PCB = config['PCB']
opt.use_dense = config['use_dense']
opt.use_NAS = config['use_NAS']
opt.stride = config['stride']

if 'h' in config:
    opt.h = config['h']
    opt.w = config['w']

if 'nclasses' in config: # to be compatible with old config files
    opt.nclasses = config['nclasses']
else:
    opt.nclasses = 751


def heatmap2d(img, arr):
    fig = plt.figure()
    ax0 = fig.add_subplot(121, title="Image")
    ax1 = fig.add_subplot(122, title="Heatmap")

    ax0.imshow(Image.open(img))
    heatmap = ax1.imshow(arr, cmap='viridis')
    fig.colorbar(heatmap)
    #plt.show()
    fig.savefig('heatmap')

data_transforms = transforms.Compose([
        transforms.Resize((opt.h, opt.w), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

image_datasets = {x: datasets.ImageFolder( os.path.join(opt.data_dir,x) ,data_transforms) for x in ['train']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=1) for x in ['train']}

imgpath = image_datasets['train'].imgs
model, _, epoch = load_network(opt.name, opt)
model.classifier.classifier = nn.Sequential()
model = model.eval().cuda()

data = next(iter(dataloaders['train']))
img, label = data
with torch.no_grad():
    x = model.model.conv1(img.cuda())
    x = model.model.bn1(x)
    x = model.model.relu(x)
    x = model.model.maxpool(x)
    x = model.model.layer1(x)
    x = model.model.layer2(x)
    output = model.model.layer3(x)
    #output = model.model.layer4(x)

print(output.shape)
heatmap = output.squeeze().sum(dim=0).cpu().numpy()
print(heatmap.shape)
#test_array = np.arange(100 * 100).reshape(100, 100)
# Result is saved as `heatmap.png`
heatmap2d(imgpath[0][0],heatmap)

layumi avatar Jul 01 '19 23:07 layumi

https://github.com/layumi/Person-reID-verification/issues/4

layumi avatar Jul 01 '19 23:07 layumi

Thanks @layumi, it worked for me. I am able to generate the heatmaps.

Please find below a snapshot of the heatmap that I generated for my dataset:

[Image: heatmap]

But what if I need the heatmaps plotted on top of the original image, as in this image taken from one of your projects? I think these heatmaps look more realistic:

[Image: Screenshot from 2019-07-04 18-05-00]

Rajat-Mehta avatar Jul 04 '19 16:07 Rajat-Mehta

@Rajat-Mehta One simple way is 0.5 * original image + 0.5 * heatmap, and then imshow the combined result.

layumi avatar Jul 05 '19 00:07 layumi

@layumi What if the dimensions of the original image and the heatmap are not the same? In my case, the original image is 256 * 128 and the heatmap is 16 * 8. The addition you suggested won't work in that case.

Rajat-Mehta avatar Jul 05 '19 18:07 Rajat-Mehta

Please resize the heatmap to match the size of the original image.

layumi avatar Jul 05 '19 23:07 layumi

@layumi How do I combine the result, i.e. 0.5 * original image + 0.5 * heatmap, after resizing the heatmap to the same size as the original image?

lynnw123 avatar Sep 23 '19 17:09 lynnw123

@lynnw123 Just add them together and clip the values (if a value is > 255, reset it to 255). For example,

combined_result = np.uint8(0.5 * x + 0.5 * y)
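
The resize-then-blend steps discussed above can be sketched as follows. This is only a minimal sketch under stated assumptions: `overlay_heatmap` is a hypothetical helper that is not part of this repository, the heatmap is rescaled to 0..255 before blending, bilinear interpolation is used for the upsampling, and the heat is drawn in the red channel purely for illustration (a matplotlib colormap would work as well).

```python
import numpy as np
from PIL import Image


def overlay_heatmap(image, heatmap, alpha=0.5):
    """Blend a low-resolution 2-D heatmap onto a PIL image (hypothetical helper).

    image:   PIL.Image of any size, e.g. 128 x 256 (width x height)
    heatmap: 2-D numpy array, e.g. 16 x 8 (height x width)
    alpha:   weight of the original image; the heatmap gets 1 - alpha
    """
    image = image.convert('RGB')
    img = np.asarray(image, dtype=np.float32)

    # Rescale the heatmap to 0..255 so it shares the image's value range.
    h = heatmap.astype(np.float32)
    h = (h - h.min()) / (h.max() - h.min() + 1e-8) * 255.0

    # Interpolate the heatmap up to the image size.
    # PIL's size is (width, height), which is what resize() expects.
    h_img = Image.fromarray(np.uint8(h)).resize(image.size, Image.BILINEAR)
    h = np.asarray(h_img, dtype=np.float32)

    # Assumption for illustration: show the heat in the red channel only.
    h_rgb = np.stack([h, np.zeros_like(h), np.zeros_like(h)], axis=-1)

    # Weighted sum, clipped to the valid 8-bit range.
    blended = alpha * img + (1.0 - alpha) * h_rgb
    return np.clip(blended, 0, 255).astype(np.uint8)
```

With the script above, this could be called as `overlay_heatmap(Image.open(imgpath[0][0]), heatmap)` and the returned array passed to `ax1.imshow`. Note that `np.resize` is not a substitute for the interpolation step: it repeats or truncates the flattened data instead of interpolating it.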

layumi avatar Sep 24 '19 10:09 layumi

The combined image did not look correct:

[Image: 1]

I made the following changes:
heatmap = np.resize(heatmap, (128,64))
Inside the heatmap2d function:
img = np.resize(Image.open(img), (128,64))
combined = np.uint8(0.5*img + 0.5 * arr)
heatmap = ax1.imshow(combined)

lynnw123 avatar Sep 24 '19 23:09 lynnw123

@lynnw123 Did you solve this issue? I mean, are you able to display the heatmap on top of the image and see which regions were activated in relation to the image?

@layumi I was wondering if I can display the 2048x16x8 activation from layer4 on top of the 256x128x3 input image. When I display the activation as 16x8, the heatmap is not clear (as shown below): it looks stretched out and the activated regions are hard to make out.

[Image: cam_sample_market_1]

bmiftah avatar Dec 16 '19 11:12 bmiftah

After we have the overall heatmap, how can we know which part each local branch is focusing on?

milliema avatar May 16 '20 10:05 milliema

Hi @milliema You may modify the code by replacing the sum with the index to visualize the heatmap of any specific layer.
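
A minimal sketch of that change, assuming "index" means selecting a single channel of the feature map instead of summing over all of them. `channel_heatmap` is a hypothetical helper, and it works on a numpy array of shape (C, H, W) such as `output.squeeze().cpu().numpy()` from the script above.

```python
import numpy as np


def channel_heatmap(features, index=None):
    """Return a 2-D heatmap from a (C, H, W) feature map (hypothetical helper).

    index=None reproduces the original behaviour (sum over all channels);
    an integer index returns the response of that single channel.
    """
    if index is None:
        return features.sum(axis=0)
    return features[index]
```

In the original script the equivalent edit would be replacing `.sum(dim=0)` with an index such as `[k]` on `output.squeeze()`.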

layumi avatar May 16 '20 10:05 layumi

> Hi @milliema You may modify the code by replacing the sum with the index to visualize the heatmap of any specific layer.

Thanks a lot for the quick reply. As I understand it, the feature map after the backbone (layer3/4) has size 24x8x1024 (HxWxC), and summing along the C dimension gives the overall 24x8 heatmap. As for the heatmaps of the local branches: since the feature map is split into parts along the H dimension, does this mean that the top 4x8 region of the overall heatmap corresponds to the heatmap of the 1st local branch? Is that correct?

milliema avatar May 16 '20 10:05 milliema

Hi @milliema Yes. If you use PCB, the 24x8 map is split evenly into 6 parts of 4x8. Otherwise, if the part pooling uses 5 parts (24 is not divisible by 5), PyTorch may split the parts with overlaps.
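
To make the split concrete, here is a small sketch that computes which rows of the overall heatmap each part covers. It assumes the usual adaptive-pooling rule (start = floor(i * H / P), end = ceil((i + 1) * H / P)); `part_bounds` and `split_parts` are hypothetical helpers and are not part of this repository.

```python
import numpy as np


def part_bounds(height, parts):
    """Row ranges [start, end) covered by each part along the H dimension."""
    bounds = []
    for i in range(parts):
        start = (i * height) // parts            # floor(i * H / P)
        end = -(-(i + 1) * height // parts)      # ceil((i + 1) * H / P)
        bounds.append((start, end))
    return bounds


def split_parts(heatmap, parts):
    """Slice a 2-D (H, W) heatmap into the row bands of each part."""
    return [heatmap[s:e] for s, e in part_bounds(heatmap.shape[0], parts)]
```

For H = 24 and 6 parts this gives the non-overlapping bands (0, 4), (4, 8), ..., (20, 24), so the top 4x8 block belongs to the first branch. For 5 parts it gives (0, 5), (4, 10), (9, 15), (14, 20), (19, 24), where neighbouring bands share one row.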

layumi avatar May 18 '20 00:05 layumi