PyTorch-Pretrained-ViT
Visualizing attention map
Hi. Does anyone know how we can get access to the attention maps?
I'm trying to figure out the same thing
Using the code below, I was able to visualize the attention maps.
Step 1:
In transformer.py, under class MultiHeadedSelfAttention(nn.Module), replace the forward method with the code below:
```python
def forward(self, x, mask):
    """
    x, q (query), k (key), v (value) : (B (batch_size), S (seq_len), D (dim))
    mask : (B (batch_size) x S (seq_len))
    * split D (dim) into (H (n_heads), W (width of head)); D = H * W
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    if mask is not None:
        mask = mask[:, None, None, :].float()
        scores -= 10000.0 * (1.0 - mask)
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D)
    h = merge_last(h, 2)
    # cache the softmaxed attention weights so they can be read out after the forward pass
    self.scores = scores
    return h
```
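With just this change, each block already caches its head-wise attention in block.attn.scores after any forward pass, so you can sanity-check it before editing the other files. A minimal sketch, assuming the model is loaded the way this repo's README shows, with ViT('B_16_imagenet1k', pretrained=True) (the imagenet1k checkpoints expect 384x384 inputs):

```python
import torch
from pytorch_pretrained_vit import ViT

model = ViT('B_16_imagenet1k', pretrained=True).eval()

with torch.no_grad():
    _ = model(torch.randn(1, 3, 384, 384))  # any correctly sized input will do

for i, block in enumerate(model.transformer.blocks):
    # scores: (batch, heads, seq_len, seq_len); here seq_len = 1 class token + 24*24 patches = 577
    print(i, block.attn.scores.shape)
```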
Step 2:
In transformer.py, under class Transformer(nn.Module), replace the forward method with the code below:
```python
def forward(self, x, mask=None):
    atten_scores = []
    for block in self.blocks:
        x = block(x, mask)
        atten_scores.append(block.attn.scores)
    return x, atten_scores
```
Step 3: In model.py, under class ViT(nn.Module), replace the forward method with the code below:
```python
def forward(self, x):
    b, c, fh, fw = x.shape
    x = self.patch_embedding(x)  # b,d,gh,gw
    x = x.flatten(2).transpose(1, 2)  # b,gh*gw,d
    if hasattr(self, 'class_token'):
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)  # b,gh*gw+1,d
    if hasattr(self, 'positional_embedding'):
        x = self.positional_embedding(x)  # b,gh*gw+1,d
    x, atten_scores = self.transformer(x)  # b,gh*gw+1,d
    # stack the per-layer scores and average over heads:
    # (layers, 1, heads, seq, seq) -> (layers, heads, seq, seq) -> (layers, seq, seq)
    att_mat = torch.stack(atten_scores).squeeze(1)
    att_mat = torch.mean(att_mat, dim=1)
    if hasattr(self, 'pre_logits'):
        x = self.pre_logits(x)
        x = torch.tanh(x)
    if hasattr(self, 'fc'):
        x = self.norm(x)[:, 0]  # b,d
        x = self.fc(x)  # b,num_classes
    return x, att_mat
```
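Note that the squeeze(1) only removes the batch dimension because it equals 1, so this version of forward is meant for single-image inference. A rough shape trace with random stand-in tensors (the sizes assume B_16 at 384x384, i.e. 12 layers, 12 heads and 577 tokens):

```python
import torch

# stand-ins for the per-block scores collected in Step 2: one (B, H, S, S) tensor per layer
atten_scores = [torch.rand(1, 12, 577, 577) for _ in range(12)]

att_mat = torch.stack(atten_scores)   # (12, 1, 12, 577, 577)
att_mat = att_mat.squeeze(1)          # (12, 12, 577, 577) -- only valid for batch size 1
att_mat = torch.mean(att_mat, dim=1)  # (12, 577, 577): one head-averaged map per layer
print(att_mat.shape)
```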
Step 4:
The forward pass now returns both the model output and the head-averaged attention maps.
```python
x, atten_weights = model(input_image.unsqueeze(0))
```
Here, atten_weights contains one head-averaged attention map per layer.
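For completeness, input_image needs the usual ViT preprocessing before the call above; a minimal sketch, assuming the B_16_imagenet1k checkpoint (384x384 inputs, normalized to roughly [-1, 1] as in the repo's README example):

```python
import torch
from PIL import Image
from torchvision import transforms

tfms = transforms.Compose([
    transforms.Resize((384, 384)),                           # B_16_imagenet1k expects 384x384
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # mean/std used in the README example
])

input_image = tfms(Image.open(img_pth).convert('RGB'))

model.eval()
with torch.no_grad():
    x, atten_weights = model(input_image.unsqueeze(0))
```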
Step 5: Apply attention rollout to atten_weights and visualize the joint attention at each layer
```python
import numpy as np
import cv2
import torch
import matplotlib.pyplot as plt
from PIL import Image

im = Image.open(img_pth)
```
```python
# atten_weights: (num_layers, seq_len, seq_len), already averaged over heads in Step 3
att_mat = atten_weights.detach()

# Add an identity matrix to account for residual connections, then re-normalize
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Attention rollout: recursively multiply the attention matrices across layers
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

grid_size = int(np.sqrt(aug_att_mat.size(-1)))

for i, v in enumerate(joint_attentions):
    # Attention from the class token to every patch, reshaped onto the patch grid
    mask = v[0, 1:].reshape(grid_size, grid_size).numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map (layer %d)' % (i + 1))
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)
```
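One practical note: in a plain Python script (as opposed to a notebook) the figures only appear once plt.show() is called; alternatively, each overlay can be written to disk from inside the loop above (the filename pattern is just an illustration):

```python
# inside the for-loop above, after the imshow calls
fig.savefig('attention_layer_%02d.png' % (i + 1), bbox_inches='tight')
plt.close(fig)
```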
Could you please share the final code or a Colab demo for extracting the attention maps, @gouttham?
https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb