FasterNet
Completely replace ResNet18 with FasterNet?
Excellent work! I need to speed up a ResNet18-based network. Should I replace ResNet18 with FasterNet entirely, or should I use PConv to replace only part of ResNet18's convolutions? Have you done any similar work? Here is the code I rewrote, but it does not seem to improve the speed much.
from typing import List

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from timm.models.layers import DropPath


class Partial_conv3(nn.Module):
    def __init__(self, dim, n_div):
        super().__init__()
        self.dim_conv3 = dim // n_div              # channels passed through the 3x3 conv
        self.dim_untouched = dim - self.dim_conv3  # channels left untouched
        # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias)
        self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # split_cat variant: works for both training and inference
        x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        return x
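For the second option (keeping ResNet18 and only swapping part of its convolutions for PConv), a rough sketch could look like the following. The helper name is made up, and it assumes torchvision's resnet18/BasicBlock layout, where conv2 always maps planes -> planes, so a channel-preserving Partial_conv3 can drop in for it:

import torchvision

def pconv_ify_resnet18(n_div: int = 4) -> nn.Module:
    # Hypothetical: replace the second 3x3 conv of every BasicBlock with a
    # partial conv. conv2 never changes channel count or stride, so the swap
    # is shape-compatible; any pretrained weights for conv2 are discarded.
    model = torchvision.models.resnet18(weights=None)  # torchvision >= 0.13 API
    for module in model.modules():
        if isinstance(module, torchvision.models.resnet.BasicBlock):
            planes = module.conv2.out_channels
            module.conv2 = Partial_conv3(planes, n_div)
    return model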
class FasterNetBlock(nn.Module):
    def __init__(self,
                 dim,
                 drop_path=0.,
                 n_div=4,
                 ratio=2.
                 ):
        super().__init__()
        self.dim = dim
        self.mlp_ratio = ratio
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.n_div = n_div

        mlp_hidden_dim = int(dim * ratio)
        mlp_layer: List[nn.Module] = [
            nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),  # 1x1 conv (expand)
            nn.BatchNorm2d(mlp_hidden_dim),                 # BN
            nn.ReLU(),                                      # ReLU
            nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)   # 1x1 conv (project back)
        ]
        self.mlp = nn.Sequential(*mlp_layer)

        self.spatial_mixing = Partial_conv3(
            dim,
            n_div,
        )

    def forward(self, x: Tensor) -> Tensor:
        shortcut = x
        x = self.spatial_mixing(x)                  # PConv
        x = shortcut + self.drop_path(self.mlp(x))  # residual: x + (Conv1x1 -> BN -> ReLU -> Conv1x1)
        return x
class PatchEmbed(nn.Module):
    def __init__(self, in_chans, embed_dim, patch_size=5, patch_stride=2, bias=False):
        super().__init__()
        # stride-2 stem: 5x5 conv halves the spatial resolution
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
                              stride=patch_stride, padding=2, bias=bias)
        self.norm = nn.BatchNorm2d(embed_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.proj(x))
        return x


class PatchMerging(nn.Module):
    def __init__(self, in_chans, merge_dim, patch_size2=2, patch_stride2=2):
        super().__init__()
        # 2x2 / stride-2 conv: downsamples between stages
        self.reduction = nn.Conv2d(in_chans, merge_dim, kernel_size=patch_size2,
                                   stride=patch_stride2, bias=False)
        self.norm = nn.BatchNorm2d(merge_dim)

    def forward(self, x: Tensor) -> Tensor:
        x = self.norm(self.reduction(x))
        return x
class FasterNet(nn.Module):
    def __init__(self, input_channels):
        super().__init__()
        self.num_ch_enc = np.array([64, 64, 128, 256, 512])
        self.num_ch_enc_ = np.array([64 + 32, 64 + 32, 128 + 64, 256 + 128, 512 + 256])

        self.fb1 = FasterNetBlock(64)
        self.fb2 = FasterNetBlock(64)
        self.fb3 = FasterNetBlock(128)
        self.fb4 = FasterNetBlock(256)

        self.embed = PatchEmbed(input_channels, 64)
        self.merge1 = PatchMerging(64, 64)
        self.merge2 = PatchMerging(64, 128)
        self.merge3 = PatchMerging(128, 256)
        self.merge4 = PatchMerging(256, 512)

    def forward(self, x: Tensor) -> List[Tensor]:
        features = []
        features.append(self.embed(x))                         # layer0, stride 2
        features.append(self.merge1(self.fb1(features[-1])))   # layer1, stride 4
        features.append(self.merge2(self.fb2(features[-1])))   # layer2, stride 8
        features.append(self.merge3(self.fb3(features[-1])))   # layer3, stride 16
        features.append(self.merge4(self.fb4(features[-1])))   # layer4, stride 32
        return features
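For reference, a minimal way I would sanity-check the output shapes and compare wall-clock latency against torchvision's resnet18 (dummy 1x3x224x224 input, CPU, nothing rigorous) is something like:

import time
import torchvision

if __name__ == "__main__":
    x = torch.randn(1, 3, 224, 224)
    net = FasterNet(input_channels=3).eval()
    ref = torchvision.models.resnet18(weights=None).eval()  # torchvision >= 0.13 API

    with torch.no_grad():
        print([f.shape for f in net(x)])  # five feature maps at strides 2, 4, 8, 16, 32

        for name, m in [("FasterNet", net), ("ResNet18", ref)]:
            for _ in range(5):    # warm-up
                m(x)
            t0 = time.perf_counter()
            for _ in range(20):
                m(x)
            print(name, (time.perf_counter() - t0) / 20, "s/iter")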