Swin-Transformer icon indicating copy to clipboard operation
Swin-Transformer copied to clipboard

做mask时候不用划分成9份吧, 4份就可以?附验证代码

Open jmjkx opened this issue 2 years ago • 12 comments

本质上只要保证新窗口内的各个patch有来源的区分性就可以,作者通过 mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 来得到一个来源图,那完全可以划分成4份就可以了啊。

489d3d064a5c802c33e0e66c4a6ddde 这是验证代码

import torch


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows



window_size = 7
H, W = 56, 56
shift_size = window_size//2


img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
w_slices = (slice(0, -window_size),
            slice(-window_size, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices:
    for w in w_slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

#### 以上划分9个窗口
####################################################################################################
#### 以下划分4个窗口

img_mask1 = torch.zeros((1, H, W, 1))  # 1 H W 1
h_slices1 = (slice(0, -shift_size),
            slice(-shift_size, None))
w_slices1 = (slice(0, -shift_size),
            slice(-shift_size, None))
cnt = 0
for h in h_slices1:
    for w in w_slices1:
        img_mask1[:, h, w, :] = cnt
        cnt += 1

mask_windows1 = window_partition(img_mask1, window_size)  # nW, window_size, window_size, 1
mask_windows1 = mask_windows1.view(-1, window_size * window_size)
attn_mask1 = mask_windows1.unsqueeze(1) - mask_windows1.unsqueeze(2)
attn_mask1 = attn_mask1.masked_fill(attn_mask1 != 0, float(-100.0)).masked_fill(attn_mask1 == 0, float(0.0))
t = attn_mask == attn_mask1
print(t.sum() == t.flatten(0).shape[0])

结果是true,是否说明直接划分四个区域就行了呢?

jmjkx avatar Apr 20 '22 02:04 jmjkx

Me too,最近精读代码想的和你一样

lifan724 avatar May 03 '22 09:05 lifan724

I think it might get more concern from authors if you translate this issue into English.

(建议把issue翻译成英文)

ain-soph avatar Jul 07 '22 19:07 ain-soph

I think it might get more concern from authors if you translate this issue into English.

(建议把issue翻译成英文)

哈哈, 好吧, 我看作者是亚研院那几个国内兄弟, 懒得搞英文了(还是菜), 哈哈。

jmjkx avatar Aug 11 '22 09:08 jmjkx

Me too,最近精读代码想的和你一样

就是分块分多了, 但是还是赞叹构思太巧妙了, 瑕不掩瑜。

jmjkx avatar Aug 11 '22 09:08 jmjkx

大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的

CHENHUI-X avatar Aug 27 '22 06:08 CHENHUI-X

大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的

自注意力机制 Q K 相乘不是出来一个矩阵嘛? 然后比如 i行 j列这个元素,代表第i个token和第j个token之间的关系。然后来自不同窗口的两个token应该没关系,所以应该强行置0。

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 

这句话就是给 算出来的矩阵标序号, 算出来来自一个窗口为0, 不同窗口不为0。 不为0的给原矩阵对应位置-100, 这样softmax出来这里就接近0, 也就达到了前面说的强行置0的效果.

jmjkx avatar Sep 01 '22 20:09 jmjkx

大佬能大概说说代码的思路吗,小弟太菜没看明白这里代码,怎么实现的

自注意力机制 Q K 相乘不是出来一个矩阵嘛? 然后比如 i行 j列这个元素,代表第i个token和第j个token之间的关系。然后来自不同窗口的两个token应该没关系,所以应该强行置0。

attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 

这句话就是给 算出来的矩阵标序号, 算出来来自一个窗口为0, 不同窗口不为0。 不为0的给原矩阵对应位置-100, 这样softmax出来这里就接近0, 也就达到了前面说的强行置0的效果.

嗯嗯,谢谢大佬!

CHENHUI-X avatar Sep 02 '22 07:09 CHENHUI-X

@jmjkx 我之前在https://github.com/pytorch/vision/pull/6246 里添加了SwinV2到torchvision里面。

你可以再提一个issue,如果验证9->4不会引起精度降低的话,这简化还是挺有价值的。

ain-soph avatar Oct 10 '22 22:10 ain-soph

@jmjkx 我之前在pytorch/vision#6246 里添加了SwinV2到torchvision里面。

你可以再提一个issue,如果验证9->4不会引起精度降低的话,这简化还是挺有价值的。

好的, 这几天抽时间写个英文的。 在 torchvision repo 提么? 还是在原来微软作者那里提

jmjkx avatar Oct 24 '22 08:10 jmjkx

@jmjkx 我看这个微软的repo好像作者已经不维护了吧。 你可以在torchvision提一个,看看maintainer们愿不愿意接受。

ain-soph avatar Oct 24 '22 08:10 ain-soph

@jmjkx 我看这个微软的repo好像作者已经不维护了吧。 你可以在torchvision提一个,看看maintainer们愿不愿意接受。

好的好的, 谢谢

jmjkx avatar Oct 24 '22 08:10 jmjkx

请问如果我想在kv上做spatial reduction的话,这个mask该怎么变呢?

ResetSun avatar Aug 07 '23 06:08 ResetSun