pytorch-grad-cam
Problem when applying Grad-CAM to a UNet
Hi sir,
I'm quite new to Python and very interested in your work! But I ran into a problem when applying CAM to a UNet. It raises an error at this line:
grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0,:]
RuntimeError: The size of tensor a (256) must match the size of tensor b (3) at non-singleton dimension 2
Do you know how to fix it? Here is my code:
target_layers = [net.down4.maxpool_conv]
with GradCAM(model=net, target_layers=target_layers, use_cuda=torch.cuda.is_available()) as cam:
    for image_id in tqdm(image_ids):
        image_path = os.path.join(test_dir, image_id + ".jpg")
        label_path = os.path.join(gt_dir, image_id + ".png")
        label = cv2.imread(label_path)
        img = cv2.imread(image_path)
        origin_shape = img.shape
        # print(origin_shape)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (256, 256))
        label = cv2.resize(label, (256, 256))
        label = (label / 255).astype(int)
        print(type(label))
        img = img.reshape(1, 1, img.shape[0], img.shape[1])
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.to(device=device, dtype=torch.float32)
        pred = net(img_tensor)
        targets = [SemanticSegmentationTarget(0, label)]
        grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        plt.imshow(cam_image)
        plt.show()
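The message is consistent with a broadcasting mismatch inside the target. `cv2.imread(label_path)` without a flag returns an H x W x 3 array even for a binary PNG, so after resizing, `label` is 256 x 256 x 3, while a single channel of the model output is 256 x 256. Assuming the target multiplies one output channel by the mask, as the tutorial's target does, the trailing dimensions 256 and 3 cannot be aligned. The sketch below reproduces only that shape arithmetic with NumPy, whose broadcasting rules PyTorch follows; the zero arrays are placeholders and `try_multiply` is a hypothetical helper.

```python
import numpy as np

# Placeholder arrays with the shapes implied by the traceback:
# one 256 x 256 output channel, and a label read by cv2.imread without
# a flag (which yields an H x W x 3 BGR array).
model_channel = np.zeros((256, 256), dtype=np.float32)
label_3ch = np.zeros((256, 256, 3), dtype=np.int64)


def try_multiply(a: np.ndarray, b: np.ndarray):
    """Return the product's shape, or None if the shapes cannot broadcast."""
    try:
        return (a * b).shape
    except ValueError:
        return None


# Trailing dimensions are compared first: 256 (width) vs 3 (channels) -> mismatch.
print(try_multiply(model_channel, label_3ch))  # None

# Keeping a single channel of the label gives a 256 x 256 mask, which broadcasts.
print(try_multiply(model_channel, label_3ch[:, :, 0]))  # (256, 256)
```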
Can you please print pred.size() and img_tensor.size()? Something seems wrong with the shapes.
Also, SemanticSegmentationTarget in the notebook receives the category and then a mask. Which SemanticSegmentationTarget are you using?
Any other details you can share will help diagnose this.
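For reference, a target along the lines of the one in the semantic segmentation tutorial notebook is sketched below. It is reproduced from memory and may differ in detail from the current notebook. It takes a category (the index of the output channel to explain) and an H x W mask, and scores the masked sum of that channel. As far as I can tell, the library calls each target on one sample's output, shaped C x H x W without the batch dimension.

```python
import numpy as np
import torch


class SemanticSegmentationTarget:
    """Scores one output channel, restricted to the pixels set in `mask`.

    Sketch modelled on the tutorial's target; not the library's own code.
    """

    def __init__(self, category: int, mask: np.ndarray):
        self.category = category            # index of the output channel to explain
        self.mask = torch.from_numpy(mask)  # H x W, same spatial size as the output
        if torch.cuda.is_available():
            self.mask = self.mask.cuda()

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        # model_output: one sample's output, C x H x W.
        # Selecting a channel gives H x W, which must broadcast against the mask.
        return (model_output[self.category, :, :] * self.mask).sum()
```

Passing a 3-channel mask to this class raises the same kind of RuntimeError, because `model_output[category, :, :]` is H x W while the mask is H x W x 3.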
pred and img_tensor both have size:
torch.Size([1, 1, 256, 256])
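Since `pred` is 1 x 1 x 256 x 256, the UNet has a single output channel. In the tutorial, the mask handed to the target is derived from the model's own prediction rather than from ground truth; for a one-channel binary network, a comparable mask can be obtained by thresholding the output. The sketch assumes the network returns raw logits (no sigmoid at the end) and uses a 0.5 probability threshold; `binary_mask_from_logits` is a hypothetical helper. Using the ground-truth label as the mask is also possible; the choice depends on whether you want to explain the prediction or the labelled region.

```python
import numpy as np


def binary_mask_from_logits(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Turn a 1 x 1 x H x W single-channel output into an H x W float32 mask.

    Assumes `logits` are raw scores, i.e. the model does not apply a sigmoid.
    """
    if logits.ndim != 4 or logits.shape[:2] != (1, 1):
        raise ValueError(f"expected shape (1, 1, H, W), got {logits.shape}")
    probs = 1.0 / (1.0 + np.exp(-logits[0, 0]))  # element-wise sigmoid
    return (probs > threshold).astype(np.float32)


# Hypothetical use with the tensor from the question:
#   mask = binary_mask_from_logits(pred.detach().cpu().numpy())
#   targets = [SemanticSegmentationTarget(0, mask)]
```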
I'm not quite sure what inputs SemanticSegmentationTarget(0, label) expects.
In my case, label is a <class 'numpy.ndarray'>
containing 0s and 1s (the label is a binary image, which I converted to an ndarray).
What data type should the first parameter of SemanticSegmentationTarget(0, label) be? Since the label only contains 0 and 1, I set the first parameter to 0. Is that right?
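Regarding the first argument: in a tutorial-style target it is used as `model_output[category, :, :]`, so it is an output-channel index rather than a pixel value of the label, and with a one-channel UNet 0 is the right choice. The remaining problems are in how the arrays are prepared. The label needs to be a single-channel H x W mask, and `show_cam_on_image` expects an H x W x 3 float32 image in [0, 1], whereas the loop passes the reshaped 1 x 1 x 256 x 256 uint8 array. The helpers below (`label_to_mask`, `gray_to_rgb_float`) are hypothetical names, and the sketch assumes the label PNG stores 0 and 255.

```python
import numpy as np


def label_to_mask(label: np.ndarray, threshold: int = 127) -> np.ndarray:
    """Return an H x W float32 mask (0.0 / 1.0) from a binary label image.

    Accepts an H x W array or the H x W x 3 array that cv2.imread returns
    when no flag is given. Assumes label pixels are 0 or 255.
    """
    if label.ndim == 3:
        # All three channels of a grayscale PNG read as BGR are identical.
        label = label[:, :, 0]
    return (label > threshold).astype(np.float32)


def gray_to_rgb_float(gray: np.ndarray) -> np.ndarray:
    """Turn an H x W uint8 grayscale image into H x W x 3 float32 in [0, 1]."""
    rgb = np.repeat(gray[:, :, None], 3, axis=2)
    return rgb.astype(np.float32) / 255.0


# Hypothetical use inside the loop from the question:
#   label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)           # H x W
#   label = cv2.resize(label, (256, 256), interpolation=cv2.INTER_NEAREST)
#   mask = label_to_mask(label)
#   gray = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2GRAY)  # imread gives BGR
#   gray = cv2.resize(gray, (256, 256))
#   img_tensor = torch.from_numpy(gray[None, None]).to(device=device, dtype=torch.float32)
#   targets = [SemanticSegmentationTarget(0, mask)]
#   grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0, :]
#   cam_image = show_cam_on_image(gray_to_rgb_float(gray), grayscale_cam, use_rgb=True)
```

Keeping the 2-D `gray` array separate from the reshaped tensor input means the same image can be reused for display without undoing the reshape.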