
Why do 2D tokens become 1D tokens and then 2D tokens again?

BeautySilly opened this issue 2 years ago · 1 comment

Hi, I have a question about van.py in VAN-Segmentation: why do 2D tokens become 1D tokens and then 2D tokens again (lines 105, 110, and 223)? The details are as follows. In the Block class:

def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.permute(0, 2, 1).view(B, C, H, W)          # <---------this: (B, N, C) -> (B, C, H, W)
        x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
                               * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
                               * self.mlp(self.norm2(x)))
        x = x.view(B, C, N).permute(0, 2, 1)              # <---------this: (B, C, H, W) -> (B, N, C)
        return x

In the VAN class:

def forward(self, x):
        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            for blk in block:
                x = blk(x, H, W)
            x = norm(x)
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()          # <---------this: (B, N, C) -> (B, C, H, W)
            outs.append(x)

        return outs

My teacher thinks this has a very profound meaning, but I think the author simply set it up this way to stay consistent with the traditional ViT and that there is no special meaning. My explanation can't convince my teacher, so I'm seeking the author's help. Looking forward to your reply! Best!

BeautySilly · Jul 27 '22

Because LayerNorm is applied over the last dimension (which needs the 1D token layout (B, N, C)), while the convolutions are applied over the last two spatial dimensions (which need the 2D layout (B, C, H, W)).

XuRuihan · Feb 21 '23
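
A minimal PyTorch sketch of this point (the shapes are arbitrary examples and nn.LayerNorm/nn.Conv2d are plain stand-ins, not the actual van.py modules): LayerNorm normalizes over the last dimension, so it wants the (B, N, C) token layout, while a spatial convolution wants (B, C, H, W), hence the permute/view round-trip.

import torch
import torch.nn as nn

B, C, H, W = 2, 64, 14, 14
N = H * W

x = torch.randn(B, N, C)                       # 1d token layout: (B, N, C)

norm = nn.LayerNorm(C)                         # normalizes over the last (channel) dim
x = norm(x)                                    # needs (B, N, C)

x = x.permute(0, 2, 1).view(B, C, H, W)        # -> 2d layout: (B, C, H, W)
conv = nn.Conv2d(C, C, kernel_size=3, padding=1)
x = conv(x)                                    # spatial conv needs (B, C, H, W)

x = x.view(B, C, N).permute(0, 2, 1)           # back to (B, N, C) for the next LayerNorm
print(x.shape)                                 # torch.Size([2, 196, 64])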