pytorch-grad-cam

Gradients and activations shape in the VIT example code

Open lunaryan opened this issue 8 months ago • 6 comments

I tried to run the ViT example code, but it fails with the following error.

Python 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from pytorch_grad_cam import GradCAM
>>> import torch
>>> model = torch.hub.load('facebookresearch/deit:main',
...     'deit_tiny_patch16_224', pretrained=True)
Using cache found in /home/.cache/torch/hub/facebookresearch_deit_main
>>> target_layers = [model.blocks[-1].norm1]
>>> image = torch.rand(1,3,224,224)
>>> cam = GradCAM(model=model, target_layers=target_layers)
>>> grayscale_cam = cam(input_tensor=image, targets=None)
(1, 197, 192)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/data4/user/miniconda3/envs/anti-dreambooth/lib/python3.9/site-packages/pytorch_grad_cam/base_cam.py", line 188, in __call__
    return self.forward(input_tensor, targets, eigen_smooth)
  File "/data4/user/miniconda3/envs/anti-dreambooth/lib/python3.9/site-packages/pytorch_grad_cam/base_cam.py", line 112, in forward
    cam_per_layer = self.compute_cam_per_layer(input_tensor, targets, eigen_smooth)
  File "/data4/user/miniconda3/envs/anti-dreambooth/lib/python3.9/site-packages/pytorch_grad_cam/base_cam.py", line 143, in compute_cam_per_layer
    cam = self.get_cam_image(input_tensor, target_layer, targets, layer_activations, layer_grads, eigen_smooth)
  File "/data4/user/miniconda3/envs/anti-dreambooth/lib/python3.9/site-packages/pytorch_grad_cam/base_cam.py", line 66, in get_cam_image
    weights = self.get_cam_weights(input_tensor, target_layer, targets, activations, grads)
  File "/data4/user/miniconda3/envs/anti-dreambooth/lib/python3.9/site-packages/pytorch_grad_cam/grad_cam.py", line 32, in get_cam_weights
    raise ValueError("Invalid grads shape."
ValueError: Invalid grads shape.Shape of grads should be 4 (2D image) or 5 (3D image).

If I comment out the error-raising check in grad_cam.py, the shape check in https://github.com/jacobgil/pytorch-grad-cam/blob/1ff3f58818baa2889f3f51d0b9759783b4333ba0/pytorch_grad_cam/base_cam.py#L74 fails as well.

Does the shape really matter? Is there a way to fix this, or at least work around it? Thanks!

Environment: PyTorch 2.1.0+cu121, grad-cam 1.5.2, Ubuntu 18.04

lunaryan • Jun 27 '24 03:06
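
For context on the shape question: the `(1, 197, 192)` printed in the session above looks like a token sequence (1 class token plus 14 x 14 patch tokens, each of width 192), whereas the error message asks for a 4-D image-like tensor. A minimal sketch of one possible workaround is to reshape the tokens back into a spatial grid before the CAM weights are computed. The function name, the 14 x 14 grid size, and the assumption of a single leading class token are illustrative choices for `deit_tiny_patch16_224` at 224 x 224 input, not something confirmed in this thread.

```python
import torch


def reshape_transform(tensor, height=14, width=14):
    """Turn ViT tokens (B, 1 + H*W, C) into an image-like tensor (B, C, H, W).

    Assumes exactly one leading class token and a height x width patch grid.
    """
    # Drop the class token: (B, 1 + H*W, C) -> (B, H*W, C)
    patches = tensor[:, 1:, :]
    # Unflatten the patch axis: (B, H*W, C) -> (B, H, W, C)
    grid = patches.reshape(tensor.size(0), height, width, tensor.size(2))
    # Move channels first, like a CNN feature map: (B, H, W, C) -> (B, C, H, W)
    return grid.permute(0, 3, 1, 2)


# Stand-in for the activations seen in the session above.
tokens = torch.rand(1, 197, 192)
spatial = reshape_transform(tokens)
print(tuple(spatial.shape))
```

`GradCAM` accepts a `reshape_transform` argument, so the function could be passed as `GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform)`. Whether this resolves the exact failure reported here has not been verified in this thread.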