pytorch-grad-cam
float16 is not supported (src data type = 23 is not supported)
I'm opening a new issue to revive this one: https://github.com/jacobgil/pytorch-grad-cam/issues/198
This is the error I'm getting; it's the same as in the issue above:
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
  File ".../lib/python3.8/site-packages/pytorch_grad_cam/base_cam.py", line 192, in __call__
    return self.forward(input_tensor,
  File ".../lib/python3.8/site-packages/pytorch_grad_cam/base_cam.py", line 105, in forward
    cam_per_layer = self.compute_cam_per_layer(input_tensor,
  File ".../lib/python3.8/site-packages/pytorch_grad_cam/base_cam.py", line 144, in compute_cam_per_layer
    scaled = scale_cam_image(cam, target_size)
  File ".../lib/python3.8/site-packages/pytorch_grad_cam/utils/image.py", line 169, in scale_cam_image
    img = cv2.resize(img, target_size)
cv2.error: OpenCV(4.7.0) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
>  - src data type = 23 is not supported
>  - Expected Ptr<cv::UMat> for argument 'src'
I'm running grad-cam 1.5.0:
Name: grad-cam
Version: 1.5.0
Summary: Many Class Activation Map methods implemented in Pytorch for classification, segmentation, object detection and more
Home-page: https://github.com/jacobgil/pytorch-grad-cam
Author: Jacob Gildenblat
Author-email: [email protected]
License:
Location: .../lib/python3.8/site-packages
Requires: matplotlib, numpy, opencv-python, Pillow, scikit-learn, torch, torchvision, tqdm, ttach
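The issue is that my input is a torch.HalfTensor (float16), so the CAM array passed to cv2.resize is float16 as well, and OpenCV doesn't seem to support that dtype for resize.
The suggested fix is to cast to float32 right before the resize: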
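To show where the float16 comes from, here is a minimal NumPy-only sketch of the normalization steps that run before the resize. The array shape and values are made up for illustration. It suggests two things: the normalization keeps the array in float16, and 23 is NumPy's internal type number for float16, which, as far as I can tell, is the "23" in the OpenCV message.

```python
import numpy as np

# Stand-in for one CAM batch produced by a half-precision model
# (shape and values are made up for illustration).
cam = np.random.rand(1, 7, 7).astype(np.float16)

# Same normalization steps that run before the resize.
img = cam[0]
img = img - np.min(img)
img = img / (1e-7 + np.max(img))

# The array is still float16 here, so cv2.resize would receive float16.
print(img.dtype)

# NumPy's type number for float16.
print(np.dtype(np.float16).num)
```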
def scale_cam_image(cam, target_size=None):
    result = []
    for img in cam:
        img = img - np.min(img)
        img = img / (1e-7 + np.max(img))
        if target_size is not None:
            img = cv2.resize(img.astype(np.float32), target_size)
        result.append(img)
    result = np.float32(result)

    return result
Converting to float32 as suggested shouldn't be an issue, since the resulting array gets converted to float32 anyway (np.float32(result)), and it should cover all users' use cases.
I'd love to open a PR with the fix above if possible. This may be tied to a specific version of OpenCV, but I think the fix is worthwhile for compatibility.