Swin-Transformer
Swin-Transformer copied to clipboard
做mask时候不用划分成9份吧, 4份就可以?附验证代码
本质上只要保证新窗口内的各个patch有来源的区分性就可以,作者通过 mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 来得到一个来源图,那完全可以划分成4份就可以了啊。
这是验证代码
import torch
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
window_size = 7
H, W = 56, 56
shift_size = window_size//2
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
#### 以上划分9个窗口
####################################################################################################
#### 以下划分4个窗口
img_mask1 = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices1 = (slice(0, -shift_size),
slice(-shift_size, None))
w_slices1 = (slice(0, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices1:
for w in w_slices1:
img_mask1[:, h, w, :] = cnt
cnt += 1
mask_windows1 = window_partition(img_mask1, window_size) # nW, window_size, window_size, 1
mask_windows1 = mask_windows1.view(-1, window_size * window_size)
attn_mask1 = mask_windows1.unsqueeze(1) - mask_windows1.unsqueeze(2)
attn_mask1 = attn_mask1.masked_fill(attn_mask1 != 0, float(-100.0)).masked_fill(attn_mask1 == 0, float(0.0))
t = attn_mask == attn_mask1
print(t.sum() == t.flatten(0).shape[0])
结果是true,是否说明直接划分四个区域就行了呢?
Me too,最近精读代码想的和你一样
I think it might get more concern from authors if you translate this issue into English.
(建议把issue翻译成英文)
I think it might get more concern from authors if you translate this issue into English.
(建议把issue翻译成英文)
哈哈, 好吧, 我看作者是亚研院那几个国内兄弟, 懒得搞英文了(还是菜), 哈哈。
Me too,最近精读代码想的和你一样
就是分块分多了, 但是还是赞叹构思太巧妙了, 瑕不掩瑜。
大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的
大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的
自注意力机制 Q K 相乘不是出来一个矩阵嘛? 然后比如 i行 j列这个元素,代表第i个token和第j个token之间的关系。然后来自不同窗口的两个token应该没关系,所以应该强行置0。
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
这句话就是给 算出来的矩阵标序号, 算出来来自一个窗口为0, 不同窗口不为0。 不为0的给原矩阵对应位置-100, 这样softmax出来这里就接近0, 也就达到了前面说的强行置0的效果.
大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的
自注意力机制 Q K 相乘不是出来一个矩阵嘛? 然后比如 i行 j列这个元素,代表第i个token和第j个token之间的关系。然后来自不同窗口的两个token应该没关系,所以应该强行置0。
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
这句话就是给 算出来的矩阵标序号, 算出来来自一个窗口为0, 不同窗口不为0。 不为0的给原矩阵对应位置-100, 这样softmax出来这里就接近0, 也就达到了前面说的强行置0的效果.
嗯嗯,谢谢大佬!
@jmjkx 我之前在https://github.com/pytorch/vision/pull/6246 里添加了SwinV2到torchvision里面。
你可以再提一个issue,如果验证9->4不会引起精度降低的话,这简化还是挺有价值的。
@jmjkx 我之前在pytorch/vision#6246 里添加了SwinV2到torchvision里面。
你可以再提一个issue,如果验证9->4不会引起精度降低的话,这简化还是挺有价值的。
好的, 这几天抽时间写个英文的。 在 torchvision repo 提么? 还是在原来微软作者那里提
@jmjkx 我看这个微软的repo好像作者已经不维护了吧。 你可以在torchvision提一个,看看maintainer们愿不愿意接受。
@jmjkx 我看这个微软的repo好像作者已经不维护了吧。 你可以在torchvision提一个,看看maintainer们愿不愿意接受。
好的好的, 谢谢
请问如果我想在kv上做spatial reduction的话,这个mask该怎么变呢?