MixFormer
About Visualizing Attention Maps
The attention-map visualization code in mixformer_online.py raises an error: RuntimeError: shape '[8, 8, 4, 4]' is invalid for input of size 2048.
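For context on the error itself: a reshape only succeeds when the target shape's element count equals the tensor's. The requested shape needs 8·8·4·4 = 1024 elements, while the sliced tensor holds 2048, exactly twice as many, which suggests the slice kept twice the expected number of elements. A quick check of the arithmetic:

```python
import math

target_shape = (8, 8, 4, 4)
needed = math.prod(target_shape)   # elements required by the reshape
actual = 2048                      # elements in the sliced tensor, per the traceback
print(needed, actual, actual // needed)  # 1024 2048 2
```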
Mainly, I would like to ask about the meanings of q_w, k_w, skip_len, etc., and why attn_weights[::3] is used when visualizing the online_template-to-template attention weights, while attn_weights[1::3] is used for the template-to-online_template attention weights.
Looking forward to your answer.
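On the [::3] / [1::3] part: those are ordinary stride-3 slices over the list of per-block attention tensors, so (per the question's reading) every third block starting at index 0 holds one attention type and the blocks starting at index 1 hold another. A minimal illustration of the indexing, with integers standing in for the attention tensors:

```python
attn_weights = list(range(9))  # stand-ins for per-block attention tensors
print(attn_weights[::3])   # indices 0, 3, 6 -> [0, 3, 6]
print(attn_weights[1::3])  # indices 1, 4, 7 -> [1, 4, 7]
```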
I am actually facing the same issue. Could you solve this problem?
After changing the code to attn[..., skip_len:(skip_len + k_w**2), skip_len:(skip_len + q_w**2)], the resulting images differ noticeably from the author's.
The original line, attn[..., skip_len:(skip_len + k_w**2)], seems to be the problem.
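For anyone trying the two-axis slice above: it cuts a k_w² × q_w² block out of the last two dimensions instead of only the last one. A small numpy sketch (the q_w, k_w, skip_len values here are made up for illustration):

```python
import numpy as np

q_w, k_w, skip_len = 2, 2, 1
attn = np.zeros((8, 9, 9))  # stand-in for (heads, keys, queries)
block = attn[..., skip_len:(skip_len + k_w ** 2), skip_len:(skip_len + q_w ** 2)]
print(block.shape)  # (8, 4, 4)
```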
Actually, I simply deleted this line (maybe not the best solution) and had no problem running the code. Hopefully we can find the right solution.
Hello, can you share your code?
# for attn in attn_weights:
#     # attn_weights_mean.append(attn[..., skip_len:(skip_len+k_w**2)].mean(dim=1).squeeze().cpu())
#     attn_weights_mean.append(attn[..., skip_len:(skip_len+k_w**2)].mean(dim=1).squeeze().reshape(shape1+shape2).cpu())
I just removed those lines.
You can try modifying it like this to get the attn maps of (s2ot, s2s, s2t), but it is not the best solution:
for attn in attn_weights:
    try:
        attn_weights_mean.append(attn[..., skip_len:(skip_len+k_w**2)].mean(dim=1).squeeze().reshape(shape1+shape2).cpu())
    except RuntimeError:  # skip layers whose slice cannot be reshaped
        pass
Note that it will fail to get the attn maps of (ot2t, t2ot).
Hopefully, we can get the right solution.
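As a slightly more explicit variant of the try/except workaround, you can keep a layer's map only when the slice's element count actually matches the target shape, instead of swallowing every exception. This is a numpy sketch, not the repo's code; skip_len, k_w, shape1, and shape2 are the names from the snippet above, and the size check is my assumption about why the reshape fails on some layers:

```python
import numpy as np

def collect_attn_maps(attn_weights, skip_len, k_w, shape1, shape2):
    """Collect head-averaged attention maps, skipping layers whose slice
    cannot be reshaped to shape1 + shape2 (instead of a bare try/except)."""
    target = tuple(shape1) + tuple(shape2)
    needed = int(np.prod(target))
    maps = []
    for attn in attn_weights:
        # slice the key axis, average over heads (axis 1), drop singleton axes
        sliced = attn[..., skip_len:(skip_len + k_w ** 2)].mean(axis=1).squeeze()
        if sliced.size == needed:
            maps.append(sliced.reshape(target))
    return maps
```

With torch tensors the structure is the same: replace axis=1 with dim=1 and .size with .numel(), and append .cpu() before storing, as in the original snippet.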