pytorch-grad-cam

Mismatch in the number of attribution slices for a 3D volume torch tensor
I have a custom Conv3D model written in PyTorch, and I want to calculate the GradCAM attributions for a 3D volume (representing an MRI). The input volume has size (30, 128, 128), but with the code below the attributions have shape (1, 128, 128, 8). I would like help understanding this mismatch in the number of attribution slices: I was expecting 30 slices, not 8. You can run the following code to reproduce it.
import torch
import torch.nn as nn
from pytorch_grad_cam import GradCAM #might need to pip install grad-cam
class Conv3D(nn.Module):
    """3D CNN classifier: five conv groups, global average pooling, linear head.

    Each group is Conv3d(kernel 3, padding 1) -> BatchNorm3d -> ReLU ->
    MaxPool3d; only the pooling settings differ between groups. For a
    (1, 1, 30, 128, 128) input the groups reduce the depth axis from 30 to
    29, 14, 7, 3 and finally 2.

    Args:
        num_labels: Number of output logits (default 2, as before).
        in_channels: Channel count C of the (N, C, D, H, W) input
            (default 1, as before).
        verbose: Keyword-only. When True (the default, matching the original
            behaviour) ``forward`` prints the tensor shape after every stage.
    """

    def __init__(self, num_labels=2, in_channels=1, *, verbose=True):
        super().__init__()
        self.num_labels = num_labels
        self.verbose = verbose
        # Submodules are created in the original order (classifier first, then
        # group1..group5) so seeded weight initialisation is unchanged.
        self.classifier = nn.Linear(512, self.num_labels)
        # NOTE(review): dropout is registered but never applied in forward().
        self.dropout = nn.Dropout(0.3)
        self.in_channels = in_channels
        # group1 pools with depth stride 1, keeping almost every slice while
        # halving height and width; later groups halve all three axes.
        self.group1 = self._conv_group(self.in_channels, 64, (2, 2, 2), (1, 2, 2))
        self.group2 = self._conv_group(64, 128, (2, 2, 2), (2, 2, 2))
        self.group3 = self._conv_group(128, 256, (2, 2, 2), (2, 2, 2))
        self.group4 = self._conv_group(256, 512, (2, 2, 2), (2, 2, 2))
        self.group5 = self._conv_group(
            512, 512, (1, 2, 2), (2, 2, 2), pool_padding=(0, 1, 1)
        )

    @staticmethod
    def _conv_group(in_ch, out_ch, pool_kernel, pool_stride, pool_padding=0):
        """Build Conv3d(k=3, p=1) -> BatchNorm3d -> ReLU -> MaxPool3d."""
        return nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(),
            nn.MaxPool3d(
                kernel_size=pool_kernel, stride=pool_stride, padding=pool_padding
            ),
        )

    def _log(self, *args):
        """Print ``args`` when verbose output is enabled."""
        # getattr default keeps instances pickled before `verbose` existed working.
        if getattr(self, "verbose", True):
            print(*args)

    def forward(self, x):
        """Return logits of shape (N, num_labels) for an (N, C, D, H, W) input."""
        self._log('input', x.shape)
        out = x.float()
        for name in ("group1", "group2", "group3", "group4", "group5"):
            out = getattr(self, name)(out)
            self._log(name, out.shape)
        # Global average pooling: merge (D, H, W) into one axis and average it.
        # flatten() also works when the activations are not view-compatible.
        y = torch.mean(out.flatten(start_dim=2), dim=2)
        self._log('avg layer', y.shape)
        logits = self.classifier(y)
        self._log('logits', logits.shape)
        return logits
# Repro: random stand-in volume of 30 slices, each 128 x 128.
model = Conv3D()
mri = torch.rand(30, 128, 128)  # a hypothetical MRI

# GradCAM hooks the first layer of group5, a Conv3d(kernel 3, padding 1), so
# its activations keep group4's output size: (1, 512, 3, 8, 8) for this input.
# The raw CAM is therefore a 3 x 8 x 8 volume, not one map per input slice.
cam_instance = GradCAM(model=model, target_layers=[model.group5[0]])

# Prepend batch and channel axes -> (1, 1, 30, 128, 128).
input_tensor = mri[None, None]
pixel_attributions = cam_instance(input_tensor=input_tensor)

print()
# NOTE(review): the trailing 8 matches the CAM's last axis (W = 8). Presumably
# the installed grad-cam version only rescales 2D maps, so the 3 x 8 x 8 CAM is
# resized like a 3 x 8 image with 8 channels to 128 x 128. Confirm against the
# installed pytorch_grad_cam before relying on this explanation.
print('pixel attributions', pixel_attributions.shape)  # (1, 128, 128, 8)
Running it prints:

input torch.Size([1, 1, 30, 128, 128])
group1 torch.Size([1, 64, 29, 64, 64])
group2 torch.Size([1, 128, 14, 32, 32])
group3 torch.Size([1, 256, 7, 16, 16])
group4 torch.Size([1, 512, 3, 8, 8])
group5 torch.Size([1, 512, 2, 5, 5])
avg layer torch.Size([1, 512])
logits torch.Size([1, 2])

pixel attributions (1, 128, 128, 8)