U-2-Net icon indicating copy to clipboard operation
U-2-Net copied to clipboard

about U2Net 安定

Open Breezewrf opened this issue 3 years ago • 1 comments

A wonderful project! But I have some questions about training: why does reducing the channel values reduce the model size? I am confused about the channels, and I'm a beginner here, so could you give me some more details? Furthermore, I wonder whether there is a model between u2net and u2netp. Is it feasible to choose medium values of mid_channel and out_channel to reach better performance than u2netp while staying smaller than u2net? Thanks a lot!

Breezewrf avatar Apr 14 '21 04:04 Breezewrf

class U2NET_DIY_26MB(nn.Module):
    """U-2-Net variant with hand-picked, intermediate channel widths.

    The network has six encoder stages (stage1..stage6), five decoder
    stages (stage5d..stage1d), one 3x3 side-output convolution per decoder
    stage plus one for stage6, and a 1x1 convolution that fuses the six
    side outputs.

    NOTE(review): RSU7/RSU6/RSU5/RSU4/RSU4F and ``_upsample_like`` are
    defined elsewhere. The RSU arguments are presumably
    ``(in_ch, mid_ch, out_ch)`` as in upstream U-2-Net -- confirm.

    Args:
        in_ch: Number of channels of the input tensor.
        out_ch: Number of channels of every side output and of the fused
            output.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET_DIY_26MB, self).__init__()

        # encoder
        # The output width of stage n must equal the input width of
        # stage n+1.
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 96)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(96, 48, 128)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(128, 64, 160)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
        self.stage6 = RSU4F(160, 96, 192)

        # decoder
        # Each decoder stage consumes the channel-wise concatenation of
        # the deeper stage's (upsampled) output and the matching encoder
        # output, so its input width is the sum of the two.
        self.stage5d = RSU4F(160 + 192, 128, 160)  # stage5 + stage6
        self.stage4d = RSU4(128 + 160, 64, 128)    # stage4 + stage5d
        self.stage3d = RSU5(96 + 128, 48, 96)      # stage3 + stage4d
        self.stage2d = RSU6(64 + 96, 32, 64)       # stage2 + stage3d
        self.stage1d = RSU7(64 + 64, 16, 48)       # stage1 + stage2d

        # Side outputs: the maps Sside(6)..Sside(1) are produced from
        # stages En6, De5..De1 by a 3x3 convolution followed by a sigmoid
        # (the sigmoid is applied in forward()).
        self.side1 = nn.Conv2d(48, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(96, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(160, out_ch, 3, padding=1)  # fed by stage5d
        self.side6 = nn.Conv2d(192, out_ch, 3, padding=1)  # fed by stage6

        # 1x1 convolution that fuses the six concatenated side outputs.
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        """Run the network on ``x``.

        Args:
            x: Input tensor with ``in_ch`` channels on dimension 1.

        Returns:
            A 7-tuple ``(d0, d1, d2, d3, d4, d5, d6)`` of sigmoid-activated
            maps: ``d0`` is the fused output and ``d1``..``d6`` are the
            side outputs of stage1d..stage5d and stage6, each passed
            through ``_upsample_like(..., d1)``.
        """
        # encoder: each stage output is kept for the skip connection and
        # pooled before entering the next stage
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))

        # decoder: bring the deeper feature map to the skip connection's
        # size (presumably what _upsample_like does -- confirm), then
        # concatenate along the channel dimension
        hx6up = _upsample_like(hx6, hx5)
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))

        hx5dup = _upsample_like(hx5d, hx4)
        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))

        hx4dup = _upsample_like(hx4d, hx3)
        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))

        hx3dup = _upsample_like(hx3d, hx2)
        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))

        hx2dup = _upsample_like(hx2d, hx1)
        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side outputs, all aligned to d1 so they can be concatenated
        d1 = self.side1(hx1d)
        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        # fused output
        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid computes the same values as F.sigmoid without the
        # deprecation warning that F.sigmoid emits on some PyTorch releases.
        return (
            torch.sigmoid(d0),
            torch.sigmoid(d1),
            torch.sigmoid(d2),
            torch.sigmoid(d3),
            torch.sigmoid(d4),
            torch.sigmoid(d5),
            torch.sigmoid(d6),
        )

piaobuliao avatar May 05 '22 08:05 piaobuliao