U-2-Net
U-2-Net copied to clipboard
about U2Net 安定
A wonderful project! But I have some questions about training: why does reducing the channel values reduce the model size? I am confused about the channels — I'm a beginner here, so could you give me some more details? Furthermore, I wonder if there is a model size between u2net and u2netp. Is it feasible to choose medium values for mid_channel and out_channel to achieve better performance than u2netp while staying smaller than u2net? Thanks a lot!
class U2NET_DIY_26MB(nn.Module):
    """Slimmed-down U^2-Net variant (~26 MB).

    Channel widths sit between u2netp (all stages 64 out) and the full
    u2net, trading some capacity for a smaller model.  Six encoder
    stages (RSU7..RSU4F) with 2x max-pooling between them, five decoder
    stages fed by concatenating the upsampled deeper feature map with
    the matching encoder output, plus six 3x3 side-output heads fused
    by a 1x1 conv.

    Args:
        in_ch:  number of input image channels (default 3, RGB).
        out_ch: number of output mask channels (default 1).

    Returns (from ``forward``): a 7-tuple of sigmoid probability maps
    ``(d0, d1, ..., d6)`` where ``d0`` is the fused prediction and
    ``d1``–``d6`` are the side outputs, all at the input resolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET_DIY_26MB, self).__init__()

        # Encoder: each stage's output channel count must equal the
        # next stage's input channel count.
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage3 = RSU5(64, 32, 96)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage4 = RSU4(96, 48, 128)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage5 = RSU4F(128, 64, 160)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(160, 96, 192)

        # Decoder: input channels = (skip from matching encoder stage)
        # + (upsampled output of the previous, deeper decoder stage).
        self.stage5d = RSU4F(160 + 192, 128, 160)  # stage5 out + stage6 out
        self.stage4d = RSU4(128 + 160, 64, 128)    # stage4 out + stage5d out
        self.stage3d = RSU5(96 + 128, 48, 96)      # stage3 out + stage4d out
        self.stage2d = RSU6(64 + 96, 32, 64)       # stage2 out + stage3d out
        self.stage1d = RSU7(64 + 64, 16, 48)       # stage1 out + stage2d out

        # Side-output heads: 3x3 conv mapping each decoder stage's
        # (and stage6's) features to out_ch prediction channels.
        self.side1 = nn.Conv2d(48, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(96, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(160, out_ch, 3, padding=1)  # from stage5d output
        self.side6 = nn.Conv2d(192, out_ch, 3, padding=1)  # from stage6 (encoder bottom) output

        # Fuse the six side outputs into the final prediction d0.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        hx = x

        # ---- encoder ----
        hx1 = self.stage1(hx)
        hx = self.pool12(hx1)

        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # ---- decoder: upsample deeper features, concat with skip ----
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # ---- side outputs, all upsampled to d1's (input) resolution ----
        d1 = self.side1(hx1d)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, d1)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, d1)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, d1)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, d1)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid (removed warning
        # since PyTorch 0.4.1); numerically identical.
        return (torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(d2),
                torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5),
                torch.sigmoid(d6))