Swin-Transformer
Question about the FLOPs
I have two questions about FLOPs.
(1) here
flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
I think it is the FLOPs for norm. We know that
FLOPs_norm
= self.num_features * (H // 32) * (W // 32)
= self.num_features * ((self.patches_resolution[0] * 4) // 32) * ((self.patches_resolution[1] * 4) // 32)
= self.num_features * (self.patches_resolution[0] // 8) * (self.patches_resolution[1] // 8)
= self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // 64
= self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // 2**6
However, self.num_layers is 4, so the code divides by 2**4 rather than 2**6, and I am confused.
(2) here
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
What is the meaning of "2" in the FLOPS of mlp?
Does anyone understand these issues?
Many thanks!
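I have the same question about (1). Did you figure it out?
About (2): the MLP in a single transformer block has two linear layers (fc1 from dim to dim * mlp_ratio, and fc2 back to dim), each costing H * W * dim * dim * mlp_ratio, as in class Mlp(nn.Module):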
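The factor of 2 can be checked by counting the two linear layers separately. The following is only a minimal sketch: the helper names `linear_flops` and `mlp_flops` are hypothetical (they are not part of the repository), biases and the activation are ignored, and the sample values (56 x 56 tokens, dim 96, mlp_ratio 4.0) are assumed Swin-T-like first-stage settings.

```python
def linear_flops(tokens: int, in_features: int, out_features: int) -> int:
    """Multiply-accumulate count of one linear layer applied to every token (bias ignored)."""
    return tokens * in_features * out_features


def mlp_flops(H: int, W: int, dim: int, mlp_ratio: float) -> int:
    """Cost of a two-layer MLP: dim -> hidden -> dim, applied to H * W tokens."""
    hidden = int(dim * mlp_ratio)
    tokens = H * W
    fc1 = linear_flops(tokens, dim, hidden)   # first layer expands the channels
    fc2 = linear_flops(tokens, hidden, dim)   # second layer projects them back
    return fc1 + fc2


# Assumed example values, for illustration only.
H, W, dim, mlp_ratio = 56, 56, 96, 4.0
per_layer_sum = mlp_flops(H, W, dim, mlp_ratio)
closed_form = int(2 * H * W * dim * dim * mlp_ratio)
print(per_layer_sum, closed_form, per_layer_sum == closed_form)
```

Both layers cost the same, which is why the closed-form expression simply doubles one layer's count.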
@slacklife I am also confused about (1).
I think you are right; the formula should be
flops += self.num_features * (self.patches_resolution[0] // (2**(self.num_layers-1))) * (self.patches_resolution[1] // (2 ** (self.num_layers-1)))
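To see how much the two expressions differ, here is a small sketch that evaluates both. The function names are hypothetical, and the values (num_features 768, patches_resolution 56 x 56, num_layers 4) are assumed Swin-T-like settings for a 224 x 224 input, not something confirmed in this thread.

```python
def norm_flops_original(num_features, patches_resolution, num_layers):
    """Expression currently in the code: divides the token count by 2 ** num_layers."""
    return num_features * patches_resolution[0] * patches_resolution[1] // (2 ** num_layers)


def norm_flops_proposed(num_features, patches_resolution, num_layers):
    """Proposed expression: each side shrinks by 2 ** (num_layers - 1)."""
    scale = 2 ** (num_layers - 1)
    return num_features * (patches_resolution[0] // scale) * (patches_resolution[1] // scale)


# Assumed example values, for illustration only.
num_features, patches_resolution, num_layers = 768, (56, 56), 4

original = norm_flops_original(num_features, patches_resolution, num_layers)
proposed = norm_flops_proposed(num_features, patches_resolution, num_layers)
print(original, proposed, original // proposed)
```

Under these assumptions the proposed expression corresponds to a 7 x 7 final feature map with 768 channels, while the original expression gives a count four times larger. The final norm is a very small share of the total either way, so the effect on the reported model FLOPs should be minor.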