pytorch-grad-cam
ScoreCAM device mismatch error
If I create ScoreCAM with a model on a non-CPU device (e.g., 'mps' on Apple Silicon) and pass an input tensor on the same device, it raises an error at this line: https://github.com/jacobgil/pytorch-grad-cam/blob/af2d53ad6c80d075946dd99235e8591522b09b3f/pytorch_grad_cam/score_cam.py#L42-L43
because upsampled is on device 'cpu' while the input tensor is not.
I think this can be fixed by simply putting
upsampled = upsampled.to(input_tensor.device)
immediately before the referenced line. (This works for me.)
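Below is a minimal, standalone sketch of where the fix goes. It is not the library's exact code. It assumes (as the error suggests) that the activations reach the masking step as a NumPy array, so torch.from_numpy produces a CPU tensor that must be moved to the input's device before the two are multiplied. The helper name masked_inputs is hypothetical, and the min-max normalization is included only to mirror what ScoreCAM roughly does before masking.

```python
import numpy as np
import torch
import torch.nn.functional as F


def masked_inputs(input_tensor: torch.Tensor, activations: np.ndarray) -> torch.Tensor:
    """Hypothetical helper approximating the ScoreCAM masking step.

    input_tensor: (B, C, H, W) tensor on any device (cpu, cuda, mps).
    activations:  (B, K, h, w) float32 NumPy array (assumed hook output).
    Returns a (B, K, C, H, W) tensor: the input masked by each normalized map.
    """
    with torch.no_grad():
        # torch.from_numpy always creates a CPU tensor.
        activation_tensor = torch.from_numpy(activations)
        upsampled = F.interpolate(
            activation_tensor,
            size=input_tensor.shape[-2:],
            mode="bilinear",
            align_corners=True,
        )
        # Proposed fix: follow the input's device instead of staying on CPU.
        upsampled = upsampled.to(input_tensor.device)

        # Min-max normalize each activation map to [0, 1].
        flat = upsampled.flatten(start_dim=2)
        maxs = flat.max(dim=-1).values[:, :, None, None]
        mins = flat.min(dim=-1).values[:, :, None, None]
        upsampled = (upsampled - mins) / (maxs - mins + 1e-8)

        # Broadcast (B, 1, C, H, W) * (B, K, 1, H, W) -> (B, K, C, H, W).
        return input_tensor[:, None, :, :] * upsampled[:, :, None, :, :]


if __name__ == "__main__":
    # Pick mps or cuda when available, otherwise fall back to cpu.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    x = torch.rand(2, 3, 32, 32, device=device)
    acts = np.random.rand(2, 4, 8, 8).astype(np.float32)
    out = masked_inputs(x, acts)
    print(out.shape, out.device)
```

Without the .to(input_tensor.device) line, the final multiplication mixes a CPU tensor with an mps/cuda tensor and PyTorch raises a device-mismatch error. Keeping the activations float32 also matters on mps, which does not support float64.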
Here's a Colab notebook reproducing the error on a 'cuda' device: https://colab.research.google.com/drive/1oQkqE2ag6GtFIl80cIJtOXvgtOTKomlI?usp=sharing
All other CAM classes work on my laptop when passed a network and tensor on the mps device.