Swin-Transformer

Equivalence between Patch Merging and Conv.

Open AHHOZP opened this issue 3 years ago • 4 comments

Hello, after looking at the code in the patch merging part, we found that the complex operation of slicing the feature map, concatenating the slices, and passing the result through a linear layer that reduces the dimension from 4C to 2C is exactly equivalent to a conv layer with kernel size 2 and stride 2.

Your operation concatenates the 4 pixels of each 2x2 patch into 1 pixel with four times the channels. Every 2x2 patch shares the same weights in your linear layer (self.reduction). A conv with kernel size 2 and stride 2 does the same thing.
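The claim that one shared matrix applied to every 2x2 patch equals a kernel-2, stride-2 convolution can be checked numerically. Below is a minimal NumPy sketch (NumPy stands in for PyTorch; the function names are hypothetical): a direct loop-based convolution using the PyTorch weight layout [C_out, C_in, 2, 2], compared with flattening each non-overlapping patch and multiplying by the reshaped weight.

```python
import numpy as np

def conv2x2_stride2_naive(x, w):
    """Direct convolution, kernel 2, stride 2, no padding, no bias.
    x: [C_in, H, W] with even H, W; w: [C_out, C_in, 2, 2]."""
    c_in, H, W = x.shape
    c_out = w.shape[0]
    out = np.zeros((c_out, H // 2, W // 2))
    for o in range(c_out):
        for i in range(H // 2):
            for j in range(W // 2):
                patch = x[:, 2 * i:2 * i + 2, 2 * j:2 * j + 2]  # [C_in, 2, 2]
                out[o, i, j] = np.sum(patch * w[o])
    return out

def patches_times_shared_matrix(x, w):
    """Flatten every 2x2 patch into one 4*C_in vector and apply the SAME
    [C_out, 4*C_in] matrix to each patch."""
    c_in, H, W = x.shape
    c_out = w.shape[0]
    # [C_in, H/2, 2, W/2, 2] -> [H/2, W/2, C_in, 2, 2] -> [H/2, W/2, 4*C_in]
    p = x.reshape(c_in, H // 2, 2, W // 2, 2).transpose(1, 3, 0, 2, 4)
    p = p.reshape(H // 2, W // 2, 4 * c_in)
    m = w.reshape(c_out, 4 * c_in)  # flattened in the same (C_in, kh, kw) order
    return (p @ m.T).transpose(2, 0, 1)  # back to [C_out, H/2, W/2]

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 6))
w = rng.standard_normal((6, 3, 2, 2))
print(np.allclose(conv2x2_stride2_naive(x, w), patches_times_shared_matrix(x, w)))  # True
```

Because the matrix `m` is reused for every patch position, this is exactly the weight sharing that makes the linear layer and the strided conv interchangeable.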

The number of parameters in this linear layer equals that of the conv layer:
linear layer params = input channels * output channels = 4C * 2C = 8 * C^2
conv layer params = kernel size * kernel size * input channels * output channels = 2 * 2 * C * (2 * C) = 8 * C^2
So, linear layer params == conv layer params.
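The parameter arithmetic above can be reproduced in a few lines of plain Python. The helper names are hypothetical, and both counts assume bias=False, as in the code later in this thread.

```python
def linear_params(c):
    # nn.Linear(4*C, 2*C, bias=False): one weight per (input, output) pair
    return (4 * c) * (2 * c)

def conv_params(c, k=2):
    # nn.Conv2d(C, 2*C, kernel_size=k, bias=False): k*k weights per (input, output) pair
    return k * k * c * (2 * c)

for c in (96, 192, 384):
    print(c, linear_params(c), conv_params(c), linear_params(c) == conv_params(c) == 8 * c * c)
```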

AHHOZP, Aug 29 '22 09:08

I agree with you that both implementations are essentially the same.

Standard Implementation Using Slicing Operation

Parameters:

self.reduction = nn.Linear(4*C, 2*C, bias=False)

Forward Pass:

x = x.view(B, H, W, C)  # [B, H*W, C] -> [B, H, W, C]
x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C)  # [B, HW/4, 4*C]

x = self.norm(x)       # LayerNorm over the 4*C concatenated channels
x = self.reduction(x)  # [B, HW/4, 4*C] -> [B, HW/4, 2*C]
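The equivalence also pins down how the two weight tensors correspond. Because torch.cat([x0, x1, x2, x3], -1) orders the four pixel offsets (row, column) as (0, 0), (1, 0), (0, 1), (1, 1), the Linear weight is a permuted rearrangement of the Conv2d weight rather than a plain reshape. The sketch below is a NumPy port of the slicing forward above with the norm omitted (all names are hypothetical), checked against a reference kernel-2, stride-2 convolution written with einsum.

```python
import numpy as np

def patch_merge_slicing(x, H, W, w_lin):
    """NumPy port of the slicing forward, norm omitted.
    x: [B, H*W, C]; w_lin: [2C, 4C] in nn.Linear layout (y = x @ w_lin.T)."""
    B, _, C = x.shape
    x = x.reshape(B, H, W, C)
    x0 = x[:, 0::2, 0::2, :]  # offset (0, 0)
    x1 = x[:, 1::2, 0::2, :]  # offset (1, 0)
    x2 = x[:, 0::2, 1::2, :]  # offset (0, 1)
    x3 = x[:, 1::2, 1::2, :]  # offset (1, 1)
    x = np.concatenate([x0, x1, x2, x3], axis=-1).reshape(B, -1, 4 * C)
    return x @ w_lin.T  # [B, HW/4, 2C]

def conv_weight_to_linear(w_conv):
    """Rearrange a Conv2d weight [2C, C, 2, 2] into the [2C, 4C] Linear weight
    that matches the x0, x1, x2, x3 concatenation order."""
    order = [(0, 0), (1, 0), (0, 1), (1, 1)]
    return np.concatenate([w_conv[:, :, dh, dw] for dh, dw in order], axis=1)

def conv2x2_stride2_channels_last(x, H, W, w_conv):
    """Reference conv (kernel 2, stride 2, no bias) on [B, H*W, C] tokens."""
    B, _, C = x.shape
    p = x.reshape(B, H // 2, 2, W // 2, 2, C)  # [b, i, dh, j, dw, c]
    out = np.einsum("bipjqc,ocpq->bijo", p, w_conv)
    return out.reshape(B, -1, w_conv.shape[0])

rng = np.random.default_rng(0)
B, H, W, C = 2, 4, 6, 3
x = rng.standard_normal((B, H * W, C))
w_conv = rng.standard_normal((2 * C, C, 2, 2))
print(np.allclose(patch_merge_slicing(x, H, W, conv_weight_to_linear(w_conv)),
                  conv2x2_stride2_channels_last(x, H, W, w_conv)))  # True
```

So a trained self.reduction weight could, in principle, be loaded into a Conv2d by inverting `conv_weight_to_linear`; the norm is a separate matter.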

The first approach uses slicing, which may look intricate at first, although after optimization it introduces minimal overhead. Nevertheless,

  • the slicing operations are complex and make the code less concise.
  • the PatchMerge concept it introduces is not strictly necessary: downsampling with a kernel size of 2 and a stride of 2 is intuitive and aligns well with common practice.

Alternative Implementation Using Conv2d

Parameters:

self.reduction = nn.Conv2d(C, C*2, kernel_size=2, stride=2, bias=False)

Forward Pass:

# forward
# [B, H*W, C] -> [B, H, W, C] -> [B, C, H, W]
x = x.view(B, H, W, C).permute(0, 3, 1, 2)
x = self.norm(x)  # note: the norm sees [B, C, H, W] here, so it cannot be the original LayerNorm(4*C)
x = self.reduction(x)
x = x.permute(0, 2, 3, 1).view(B, -1, C*2)

The second approach uses Conv2d, which looks more elegant; however,

  • the permute operations might introduce additional data-movement overhead (permute itself only rewrites strides, but later operations may need to copy the data).
  • it's noteworthy that no contiguous() call is needed before the final view, which suggests that this extra data movement might not actually occur.
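Why the final view works without .contiguous() comes down to strides. A view can flatten two adjacent dimensions only when stride[i] == stride[i+1] * size[i+1]. The plain-Python sketch below (hypothetical helpers, simplified to ignore size-1 dimensions, with illustrative sizes) applies that rule to the conv output after permute(0, 2, 3, 1): merging H/2 with W/2, which is what view(B, -1, C*2) does, is legal, while merging W/2 with 2C would not be.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def permute(shape, strides, dims):
    """permute only reorders sizes and strides; no data is moved."""
    return [shape[d] for d in dims], [strides[d] for d in dims]

def can_flatten(shape, strides, i):
    """Dims i and i+1 can be merged without a copy iff stride[i] == stride[i+1] * size[i+1].
    Simplified: ignores the special handling of size-1 dims."""
    return strides[i] == strides[i + 1] * shape[i + 1]

# conv output [B, 2C, H/2, W/2] -> permute(0, 2, 3, 1) -> [B, H/2, W/2, 2C]
shape, strides = permute([2, 192, 28, 28], contiguous_strides([2, 192, 28, 28]), (0, 2, 3, 1))
print(shape, strides)                   # [2, 28, 28, 192] [150528, 28, 1, 784]
print(can_flatten(shape, strides, 1))   # True: H/2 and W/2 merge into HW/4
print(can_flatten(shape, strides, 2))   # False: W/2 and 2C are not adjacent in memory
```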

yan-mingyuan, Aug 11 '23 01:08

You are right, but tensor slicing and concatenation also cost time. I don't know how much, or which costs more compared with a contiguous copy. Another problem is LayerNorm: in the 'Using Conv2d' code, you will find that the norm differs from the original implementation. Also, BatchNorm rather than LayerNorm is typically used in CNNs, so I guess this may be another reason why they use patch merging instead of a conv.
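The normalization difference shows up even on a single 2x2 patch. The original code normalizes the concatenated 4C vector, whereas a channel-wise norm placed before the conv can only normalize each pixel over its own C channels. The NumPy sketch below assumes a LayerNorm without its learnable scale and shift, and a per-pixel norm for the Conv2d variant; `layer_norm` is a hypothetical helper.

```python
import numpy as np

def layer_norm(v, eps=1e-5):
    """LayerNorm over the last axis, without the learnable affine parameters."""
    mean = v.mean(axis=-1, keepdims=True)
    var = v.var(axis=-1, keepdims=True)
    return (v - mean) / np.sqrt(var + eps)

C = 3
pixels = np.arange(4 * C, dtype=float).reshape(4, C)  # one 2x2 patch: 4 pixels x C channels

# Original order: concatenate to 4C, then normalize over all 4C values
merged_then_norm = layer_norm(pixels.reshape(4 * C))
# Conv-style order: normalize each pixel over its own C channels, then merge
norm_then_merged = layer_norm(pixels).reshape(4 * C)

print(np.allclose(merged_then_norm, norm_then_merged))  # False
```

With this input every pixel normalizes to the same three values on its own, while the joint 4C normalization keeps the differences between pixels, so the two placements are not interchangeable.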

AHHOZP, Aug 15 '23 02:08

I agree with your perspective. Both tensor slicing and concatenation involve accessing non-contiguous memory regions, and the resulting poor locality can hurt performance.
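The memory-layout point can be illustrated with NumPy, whose behavior here is similar to PyTorch's: a step-2 slice is a strided view that shares memory with the input, and concatenation gathers those scattered elements into a new contiguous buffer. The sizes are arbitrary illustrative values.

```python
import numpy as np

x = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)  # [B, H, W, C]
x0 = x[:, 0::2, 0::2, :]  # strided view: every other row and column

print(np.shares_memory(x0, x))   # True: slicing copies nothing
print(x0.flags["C_CONTIGUOUS"])  # False: the selected elements are scattered in memory
print(x.strides, x0.strides)     # (192, 48, 12, 4) (192, 96, 24, 4): H and W strides doubled

# Concatenation reads the four strided views and writes a fresh contiguous array
merged = np.concatenate(
    [x0, x[:, 1::2, 0::2, :], x[:, 0::2, 1::2, :], x[:, 1::2, 1::2, :]], axis=-1
)
print(merged.flags["C_CONTIGUOUS"], np.shares_memory(merged, x))  # True False
```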

Turning to normalization: while BatchNorm2d is common in computer vision, LayerNorm, which is typical in transformer-based NLP models, is used here, possibly because of the transformer architecture's influence. This choice is independent of the decision between tensor transposition (permute plus Conv2d) and slicing-based patch merging, so it shouldn't weigh heavily on that comparison.

yan-mingyuan, Aug 15 '23 14:08