VisionTransformer-Pytorch
How to visualize attention map
Hi,
I want to visualize the attention map. I found https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb
In this repo, I did not find a vis option for the attention map.
(If there is one, please let me know and I'd appreciate it.)
So, I decided to add this to model.py, like this:
# In VisionTransformer
def forward(self, x):
    feat, attn_weights = self.extract_features(x)
    # classifier
    logits = self.classifier(feat[:, 0])
    return logits, attn_weights

# In Encoder
def forward(self, x):
    attn_weights = []
    out = self.pos_embedding(x)
    for layer in self.encoder_layers:
        out, weights = layer(out)
        attn_weights.append(weights)
    out = self.norm(out)
    return out, attn_weights

# In SelfAttention
def forward(self, x):
    b, n, _ = x.shape
    q = self.query(x, dims=([2], [0]))
    k = self.key(x, dims=([2], [0]))
    v = self.value(x, dims=([2], [0]))

    q = q.permute(0, 2, 1, 3)
    k = k.permute(0, 2, 1, 3)
    v = v.permute(0, 2, 1, 3)

    attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
    attn_weights = F.softmax(attn_weights, dim=-1)
    out = torch.matmul(attn_weights, v)
    out = out.permute(0, 2, 1, 3)
    out = self.out(out, dims=([2, 3], [0, 1]))
    return out, attn_weights
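With these changes, I get the per-layer weights back from the forward pass and turn them into a single map roughly like this, following the attention-rollout recipe from the notebook linked above (just a sketch of what I do; model is the modified VisionTransformer, x is a single preprocessed image, and everything else is my own choice):

import numpy as np
import torch

with torch.no_grad():
    logits, attn_weights = model(x)              # x: (1, 3, H, W)

att = torch.stack(attn_weights).squeeze(1)       # (layers, heads, tokens, tokens)
att = att.mean(dim=1)                            # average over heads

# add the identity to account for residual connections, then renormalize
att = att + torch.eye(att.size(-1), device=att.device)
att = att / att.sum(dim=-1, keepdim=True)

# multiply the attention matrices across layers (attention rollout)
joint = att[0]
for i in range(1, att.size(0)):
    joint = torch.matmul(att[i], joint)

# attention from the [CLS] token to every patch, reshaped to the patch grid
grid = int(np.sqrt(joint.size(-1) - 1))
mask = joint[0, 1:].reshape(grid, grid).cpu().numpy()
# upsample mask to the input resolution and overlay it on the image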
The result I get, though, looks quite different from the attention map in the link above (I used the pretrained weights from this repo), so I am not sure if my results are correct. I would be happy to hear your answer.
Thanks.
Looks good to me, but one thing you should pay attention to is that vit-model-1 is fine-tuned on the cassava-leaf-disease-classification task, so you would want to visualize an image from that dataset. That task is quite different from object classification and focuses on the low-level texture of the input leaf. To visualize the attention map of a dog, you can use the pre-trained models here.
Anyway, it is a good first try. I'm still hesitant about how to extract the attention map, since I don't want it to affect the inference process, that is, I don't want to modify the forward function. Maybe later I will check some best practices around hooks. If you are willing to, you can make a PR with your implementation.
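For example, something along these lines might work without touching model.py, recomputing the weights inside a forward hook (an untested sketch; the import path may differ depending on how the package is installed, the query/key/scale attributes are the ones used in your snippet, and model / x are whatever you already use for inference):

import torch
import torch.nn.functional as F
from vision_transformer_pytorch.model import SelfAttention

attn_maps = {}

def make_hook(name):
    def hook(module, inputs, output):
        # recompute the attention weights from the layer input so the
        # original forward pass stays untouched
        h = inputs[0]
        q = module.query(h, dims=([2], [0])).permute(0, 2, 1, 3)
        k = module.key(h, dims=([2], [0])).permute(0, 2, 1, 3)
        weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / module.scale, dim=-1)
        attn_maps[name] = weights.detach()
    return hook

handles = [m.register_forward_hook(make_hook(name))
           for name, m in model.named_modules() if isinstance(m, SelfAttention)]

with torch.no_grad():
    logits = model(x)

for handle in handles:
    handle.remove()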
Thanks for the answer. I used your recommended pre-trained models.
Here is the result for a dog.
The original attention map from https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb is below:
It seems to show something, but I'm not sure. What do you think about this part?
If you think it is okay, a simple vis flag could minimize the influence on inference: if vis is False, the model runs the original forward function.
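Something like this is what I have in mind (only a sketch; the flag name and default are my own choice):

# In SelfAttention (Encoder and VisionTransformer would pass vis through in the same way)
def forward(self, x, vis=False):
    b, n, _ = x.shape
    q = self.query(x, dims=([2], [0])).permute(0, 2, 1, 3)
    k = self.key(x, dims=([2], [0])).permute(0, 2, 1, 3)
    v = self.value(x, dims=([2], [0])).permute(0, 2, 1, 3)

    attn_weights = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / self.scale, dim=-1)
    out = torch.matmul(attn_weights, v).permute(0, 2, 1, 3)
    out = self.out(out, dims=([2, 3], [0, 1]))

    if vis:
        return out, attn_weights   # expose the weights only when asked
    return out                     # vis=False: identical to the original forward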
For people still looking for a solution, my package NoPdb allows capturing attention weights from pretty much any Transformer implementation without any modifications to the code. See a Colab notebook showing how to do this for ViT (a different implementation).
In this case, it would be something like:
import nopdb

with nopdb.capture_calls(SelfAttention.forward) as calls:
    logits = model(x)

calls[0].locals["attn_weights"]  # attention weights of the first layer
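And to collect the weights from every layer at once (assuming each captured call has the same local name):

attn_weights = [call.locals["attn_weights"] for call in calls]  # one tensor per layer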
Hi, when I try to implement the changes by @piantic, this is the error I am getting:
Traceback (most recent call last):
  File "C:\Users\Surya\Desktop\Automatic-Pain-Estimation-MQP\scripts\Visualize_Attention_Map.py", line 96, in
    feat = self.transformer(emb)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 177, in forward
    out, weights = layer(out)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\vision_transformer_pytorch\model.py", line 139, in forward
    out = self.dropout(out)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\modules\dropout.py", line 58, in forward
    return F.dropout(input, self.p, self.training, self.inplace)
  File "C:\Users\Surya\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\nn\functional.py", line 1169, in dropout
    return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
TypeError: dropout(): argument 'input' (position 1) must be Tensor, not tuple
Is there anything else I need to do? I feel that there might be some change that needs to be made in the EncoderBlock part of the model.py file.
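Maybe something like this is needed there? (Just my guess; apart from self.dropout, which appears in the traceback, the attribute names are assumptions I haven't checked against the actual model.py.)

# In EncoderBlock (sketch)
def forward(self, x):
    residual = x
    out = self.norm1(x)
    out, weights = self.attn(out)   # SelfAttention now returns (out, attn_weights)
    out = self.dropout(out)         # dropout must receive the tensor, not the tuple
    out = out + residual

    residual = out
    out = self.norm2(out)
    out = self.mlp(out)
    out = out + residual
    return out, weights             # hand the weights up to the Encoder loop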
Hi, @Suryanshg.
This is my example notebook for visualizing the attention map using this repo: https://www.kaggle.com/code/piantic/vision-transformer-vit-visualize-attention-map/notebook
And you can see a visualized version of ViT at the link below: https://www.kaggle.com/datasets/piantic/visiontransformerpytorch121
I hope this helps you. Thanks.