pytorch-grad-cam
Can GradCAM be used in Transformers?
GradCAM was originally devised for CNNs, but can it also be applied to Transformers or other architectures with self-attention?
You can use GradCAM with transformers by reshaping the intermediate activations into CNN-like 4D tensors. Every method implemented in the library takes, I think, a parameter called reshape_transform. You can give it a simple function that reshapes a batch of 2D token tensors into a batch of 3D spatial feature maps. There is an example in the wiki, I think; I use this:
def reshape_transform(tensor, height=14, width=14):
    # Drop the class token and lay the remaining patch tokens
    # out on a height x width grid.
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # Bring the channels to the first dimension,
    # like in CNNs.
    result = result.transpose(2, 3).transpose(1, 2)
    return result
Edit: you can find this exact function in the wiki.
There are also many examples with different transformer variants here: https://jacobgil.github.io/pytorch-gradcam-book/HuggingFace.html
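As a quick sanity check, here is a sketch that runs the function above on a dummy activation, assuming ViT-B/16-style shapes (a 224x224 input gives 1 class token plus 14x14 patch tokens with 768 channels; other transformer variants need different height/width values):

```python
import torch


def reshape_transform(tensor, height=14, width=14):
    # Drop the class token and lay the patch tokens out on a 2D grid.
    result = tensor[:, 1:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    # Move channels to dimension 1, like a CNN feature map.
    return result.transpose(2, 3).transpose(1, 2)


# Dummy ViT-B/16 activation: batch of 2, 1 class token + 196 patch tokens, 768 channels.
tokens = torch.randn(2, 197, 768)
feature_map = reshape_transform(tokens)
print(feature_map.shape)  # torch.Size([2, 768, 14, 14])
```

Once the output shape looks right, the function can be passed as the reshape_transform argument when constructing the CAM object, e.g. GradCAM(model=model, target_layers=target_layers, reshape_transform=reshape_transform).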