vit-explain
Normalizing attention scores (rows sum to 1) seems not right
Hi, thanks for sharing this nice work.
I noticed that you normalize the attention scores (each row sums to 1), as described in the original attention rollout paper:
I = torch.eye(attention_heads_fused.size(-1))
a = (attention_heads_fused + 1.0*I)/2
a = a / a.sum(dim=-1)
But it seems that when dividing by the row-wise sum of the attention scores, keepdim=True should be applied to ensure that each row actually sums to 1 after normalization:
a = a / a.sum(dim=-1,keepdim=True)
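
For reference, a minimal sketch (using a toy 3x3 matrix, not the repo's actual tensors) showing why keepdim matters here:

import torch

# Toy fused-attention matrix: 3 tokens, non-negative scores (hypothetical example).
attention_heads_fused = torch.rand(3, 3)
I = torch.eye(attention_heads_fused.size(-1))
a = (attention_heads_fused + 1.0 * I) / 2

# Without keepdim: a.sum(dim=-1) has shape (3,), which broadcasts along the
# last dimension, so each column gets divided by another row's sum.
wrong = a / a.sum(dim=-1)

# With keepdim: the sum has shape (3, 1), so each row is divided by its own sum.
right = a / a.sum(dim=-1, keepdim=True)

print(wrong.sum(dim=-1))  # generally not all ones
print(right.sum(dim=-1))  # tensor([1., 1., 1.])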
Maybe I'm wrong, so please double-check this issue. Thanks!