pytorch-grad-cam

Problem when applying Grad-CAM to a Unet

Open tjboise opened this issue 2 years ago • 2 comments

Hi sir,

I'm quite new to Python and very interested in your work, but I ran into a problem when applying CAM to a Unet. The error is raised at this line:

grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0,:]

RuntimeError: The size of tensor a (256) must match the size of tensor b (3) at non-singleton dimension 2

Do you know how to fix it? Here is my code:

target_layers = [net.down4.maxpool_conv]
with GradCAM(model=net, target_layers=target_layers, use_cuda=torch.cuda.is_available()) as cam:
    for image_id in tqdm(image_ids):

        image_path = os.path.join(test_dir, image_id + ".jpg")
        label_path = os.path.join(gt_dir, image_id + ".png")

        label = cv2.imread(label_path)
        img = cv2.imread(image_path)

        origin_shape = img.shape
        # print(origin_shape)

        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        img = cv2.resize(img, (256, 256))
        label = cv2.resize(label, (256, 256))
        label = (label / 255).astype(int)
        print(type(label))

        img = img.reshape(1, 1, img.shape[0], img.shape[1])

        img_tensor = torch.from_numpy(img)

        img_tensor = img_tensor.to(device=device, dtype=torch.float32)

        pred = net(img_tensor)

        targets = [SemanticSegmentationTarget(0, label)]
        grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True)
        plt.imshow(cam_image)
        plt.show()



tjboise, Jul 08 '22 18:07

Can you please print pred.size() and img_tensor.size()? Something seems wrong with the shapes.

Also, the SemanticSegmentationTarget in the notebook receives the label and then a mask. Which SemanticSegmentationTarget are you using?
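
For reference, a minimal sketch of what such a target looks like, assuming it takes a channel index followed by a 2-D mask and is called with one sample's output of shape (C, H, W). The `.to(model_output.device)` call and the toy tensors are assumptions of this sketch, not the notebook's exact code.

```python
import numpy as np
import torch


class SemanticSegmentationTarget:
    """Sums the model's scores for one output channel inside a pixel mask."""

    def __init__(self, category, mask):
        # category: integer index of the output channel (class) to explain.
        # mask: 2-D float array of shape (H, W), 1 where the class is present.
        self.category = category
        self.mask = torch.from_numpy(mask)

    def __call__(self, model_output):
        # model_output is a single sample's output with shape (C, H, W).
        mask = self.mask.to(model_output.device)
        return (model_output[self.category, :, :] * mask).sum()


# Toy check with a single-channel 4x4 "prediction" (stand-in for a real model).
output = torch.ones(1, 4, 4)
mask = np.zeros((4, 4), dtype=np.float32)
mask[:2, :] = 1.0  # top two rows belong to the class

score = SemanticSegmentationTarget(0, mask)(output)
```

With a definition like this, the mask has to broadcast against a (H, W) slice of the output, so its shape matters as much as its values.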

And any other details you can share will help diagnose this.

jacobgil, Jul 09 '22 16:07

pred and img_tensor both have size torch.Size([1, 1, 256, 256]).

I'm not sure what input types SemanticSegmentationTarget(0, label) expects.

In my case, label is a <class 'numpy.ndarray'> containing only 0 and 1. (The label is a binary image that I convert to an ndarray.)

What data type should the first parameter of SemanticSegmentationTarget(0, label) be? Because the label only contains 0 and 1, I set the first parameter to 0. Is that right?

tjboise, Jul 09 '22 22:07
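
A possible explanation, sketched under assumptions rather than confirmed in this thread: `cv2.imread` called without a flag returns a 3-channel array, so `label` would have shape (256, 256, 3), while a target that multiplies the mask with one output channel works on a (256, 256) slice. That would be consistent with the 256-vs-3 mismatch in the error. The sketch below uses a synthetic array in place of the real label file; `to_mask` is a hypothetical helper name, and the 127 threshold is an arbitrary choice.

```python
import numpy as np


def to_mask(label):
    """Hypothetical helper: turn a label image into a 2-D float32 mask of 0s and 1s."""
    label = np.asarray(label)
    if label.ndim == 3:
        # Assumes all channels of a binary label image are identical,
        # so the first channel carries the whole mask.
        label = label[:, :, 0]
    return (label > 127).astype(np.float32)


# Synthetic stand-in for cv2.imread(label_path): 3 channels, values 0 or 255.
label = np.zeros((256, 256, 3), dtype=np.uint8)
label[64:192, 64:192, :] = 255

# Stand-in for model_output[0, :, :], one output channel of shape (H, W).
channel_scores = np.ones((256, 256), dtype=np.float32)

try:
    channel_scores * (label / 255)  # (256, 256) * (256, 256, 3) cannot broadcast
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

mask = to_mask(label)  # shape (256, 256), dtype float32
masked_sum = float((channel_scores * mask).sum())
```

On the first parameter: if the target indexes the output as `model_output[category, :, :]`, the category is a channel index rather than a pixel value, so with a single output channel 0 would be the only valid choice. Reading the label with `cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)` is another way to get a 2-D array. Separately, `show_cam_on_image` expects an RGB float image in [0, 1] with shape (H, W, 3), so the reshaped `img` would likely need converting before that call.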