
Completely replace ResNet18 with FasterNet?

Open Embarrassing1 opened this issue 1 year ago • 2 comments

Excellent work! I need to speed up a ResNet18-based network. Should I replace ResNet18 with FasterNet entirely, or should I use PConv to replace only part of ResNet18's convolutions? Have you done any similar work? Here's the code I rewrote for the full replacement, but it doesn't seem to improve the speed much; after it I've also sketched the partial-replacement option and how I'm timing things.

from typing import List

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from timm.models.layers import DropPath  # stochastic depth


class Partial_conv3(nn.Module):

    def __init__(self, dim, n_div):
        super().__init__()
        self.dim_conv3 = dim // n_div
        self.dim_untouched = dim - self.dim_conv3
        # Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias):
        # a 3x3 conv applied only to the first dim // n_div channels
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # for training/inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)

        return x

class FasterNetBlock(nn.Module):
    def __init__(self,
                 dim,
                 drop_path=0,
                 n_div=4,
                 ratio=2.
                 ):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * ratio)

        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False), # Conv1x1
            nn.BatchNorm2d(mlp_hidden_dim), # BN layer
            nn.ReLU(), # ReLU layer
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False) # Conv1x1
        ]

        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
        )

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)  # PConv: 3x3 conv on dim // n_div channels
        x = shortcut + self.drop_path(self.mlp(x))  # x + DropPath(Conv1x1 -> BN -> ReLU -> Conv1x1)
        return x
    
class PatchEmbed(nn.Module):

    def __init__(self, in_chans, embed_dim, patch_size=5, patch_stride=2, bias=False):
        super().__init__()
        # padding=2 matches the default 5x5 kernel so stride 2 halves the resolution
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, padding=2, bias=bias)
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):

    def __init__(self, in_chans, merge_dim, patch_size2=2, patch_stride2=2):
        super().__init__()
        self.reduction = nn.Conv2d(in_chans, merge_dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)
        self.norm = nn.BatchNorm2d(merge_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x

class FasterNet(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        
        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
        self.num_ch_enc_ = np.array([64+32, 64+32, 128+64, 256+128, 512+256])

        self.fb1 = FasterNetBlock(64)
        self.fb2 = FasterNetBlock(64)
        self.fb3 = FasterNetBlock(128)
        self.fb4 = FasterNetBlock(256)

        self.embed = PatchEmbed(input_channels, 64)
        self.merge1 = PatchMerging(64, 64)
        self.merge2 = PatchMerging(64, 128)
        self.merge3 = PatchMerging(128, 256)
        self.merge4 = PatchMerging(256, 512)

    def forward(self, x: Tensor) -> List[Tensor]:
        features = [] 
        features.append(self.embed(x)) # layer0
        features.append(self.merge1(self.fb1(features[-1]))) # layer1
        features.append(self.merge2(self.fb2(features[-1]))) # layer2
        features.append(self.merge3(self.fb3(features[-1]))) # layer3
        features.append(self.merge4(self.fb4(features[-1]))) # layer4

        return features
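
For reference, here is a rough sketch of the other option I'm asking about: keeping the ResNet18 layout and only swapping a PConv-style block into the stride-1 BasicBlock of each stage. PartialBasicBlock and pconvify_resnet18 are names I made up for this sketch (they are not from the FasterNet paper); it reuses the Partial_conv3 above, and torchvision's resnet18 is just the starting point.

import torch.nn as nn
from torchvision.models import resnet18

class PartialBasicBlock(nn.Module):
    # My own drop-in for a stride-1 BasicBlock: PConv for spatial mixing,
    # then the same Conv1x1 -> BN -> ReLU -> Conv1x1 MLP as FasterNetBlock.
    def __init__(self, dim, n_div=4, ratio=2.0):
        super().__init__()
        self.pconv = Partial_conv3(dim, n_div)
        hidden = int(dim * ratio)
        self.mlp = nn.Sequential(
            nn.Conv2d(dim, hidden, 1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, dim, 1, bias=False),
        )

    def forward(self, x):
        return x + self.mlp(self.pconv(x))

def pconvify_resnet18():
    # Replace only the second (stride-1) block of each stage, so the
    # downsampling blocks and feature-map sizes stay unchanged.
    model = resnet18(weights=None)  # weights=None needs torchvision >= 0.13; older versions use pretrained=False
    for name in ("layer1", "layer2", "layer3", "layer4"):
        stage = getattr(model, name)
        dim = stage[0].conv1.out_channels  # 64 / 128 / 256 / 512
        stage[1] = PartialBasicBlock(dim)
    return model

And this is roughly how I compare latency (the batch size, resolution, repetition count, and device are placeholders for my actual setting):

import time
import torch

def latency_ms(model, size=(1, 3, 224, 224), reps=50, device="cuda"):
    # average forward-pass time in milliseconds; assumes a CUDA device is available
    model = model.to(device).eval()
    x = torch.randn(size, device=device)
    with torch.no_grad():
        for _ in range(10):  # warm-up
            model(x)
        if device == "cuda":
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(reps):
            model(x)
        if device == "cuda":
            torch.cuda.synchronize()
    return (time.perf_counter() - t0) / reps * 1000

print(latency_ms(FasterNet(3)), latency_ms(pconvify_resnet18()))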

Embarrassing1 · Jul 11 '23 09:07