
Help for my GuidedConv Implementation

Open godspeed1989 opened this issue 4 years ago • 2 comments

Dear author, I tried to reimplement your GuidedConv based on your paper, but I found the training process is very unstable. Here is my PyTorch implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class _GuideConv(nn.Module):
    def __init__(self, img_in, sparse_in, sparse_out, K=3):
        super(_GuideConv, self).__init__()
        # KGL: kernel generating layer
        self.conv1 = nn.Conv2d(img_in, K*K*sparse_in, kernel_size=3,
                               stride=1, padding=1, groups=1, bias=False)
        self.bn1 = nn.BatchNorm2d(K*K*sparse_in)

        self.fc = nn.Linear(img_in, sparse_in * sparse_out, bias=False)
        self.bn2 = nn.BatchNorm2d(sparse_out)

        self.img_in = img_in
        self.sparse_in = sparse_in
        self.sparse_out = sparse_out

    def forward(self, G, S):
        '''
        G: input guidance
        S: input source feature
        '''
        # spatially-variant
        W1 = self.conv1(G) # [B,Cin*K*K,H,W]
        W1 = self.bn1(W1)
        depths = torch.chunk(S, self.sparse_in, 1) # Cin*[B,1,H,W]
        kernels = torch.chunk(W1, self.sparse_in, 1) # Cin*[B,K*K,H,W]
        S1 = []
        for i in range(self.sparse_in):
            S1.append(torch.sum(depths[i]*kernels[i], 1, keepdim=True))
        S1 = torch.cat(S1, 1) # [B,Cin,H,W]

        # cross-channel conv
        W2 = F.adaptive_avg_pool2d(G, (1, 1))
        B = W2.size(0)
        W2 = W2.reshape(B, -1) # (b,img_in)
        W2 = self.fc(W2)
        W2 = W2.view([B, self.sparse_out, self.sparse_in])

        depths = torch.chunk(S1, B, 0) # B*[1,Cin,H,W]
        kernels = torch.chunk(W2, B, 0) # B*[1,Cout,Cin]
        S2 = []
        for i in range(B):
            weight = kernels[i][0].unsqueeze(-1).unsqueeze(-1) # [Cout,Cin,1,1]
            S2.append(F.conv2d(depths[i], weight, bias=None, stride=1, padding=0))
        S2 = torch.cat(S2, 0) # [B,Cout,H,W]
        S2 = F.relu(self.bn2(S2))

        return S2
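As an aside, the per-sample F.conv2d loop in the cross-channel step can be collapsed into a single grouped convolution by folding the batch into the channel axis. This is a standard trick, sketched here with arbitrary toy shapes (not the paper's actual sizes):

```python
import torch
import torch.nn.functional as F

B, Cin, Cout, H, W = 4, 6, 8, 8, 8
S1 = torch.randn(B, Cin, H, W)
W2 = torch.randn(B, Cout, Cin)   # per-sample 1x1 kernels from the fc layer

# Loop version, as in the snippet above: one conv2d call per sample
out_loop = torch.cat([
    F.conv2d(S1[i:i+1], W2[i].unsqueeze(-1).unsqueeze(-1)) for i in range(B)
], 0)

# Grouped-conv version: treat the B samples as B groups of channels,
# so one conv2d call applies each sample's own 1x1 kernel
out_grouped = F.conv2d(
    S1.reshape(1, B * Cin, H, W),
    W2.reshape(B * Cout, Cin, 1, 1),
    groups=B,
).reshape(B, Cout, H, W)
```

Both versions produce the same [B,Cout,H,W] tensor; the grouped form just avoids the Python-level loop over the batch.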

Could you give me some advice, or just open-source this part of the code? :tada:

godspeed1989 avatar Jul 24 '20 04:07 godspeed1989

Thanks for your interest. As the paper is still in peer review and I'm working hard on a new project these days, the code still has to be cleaned up and will be released later. (Sorry for this.) I took a glance at your implementation. In S1.append(torch.sum(depths[i]*kernels[i], 1, keepdim=True)), the receptive field of this convolution operation should be K, but in your implementation it's 1.

kakaxi314 avatar Jul 26 '20 03:07 kakaxi314

Thanks for replying. You mean I should increase the receptive field by expanding (or reorganizing) depths[i] from [1,H,W] to [K*K,H,W], so that each output pixel aggregates its KxK neighborhood? Are there any other unreasonable places in my implementation? Thank you again.
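If I understand correctly, the fix would look something like the following sketch: use F.unfold to gather each pixel's KxK neighborhood, then take the weighted sum with the per-pixel kernels. (Toy shapes, and this is my guess at the intended operation, not your released code.)

```python
import torch
import torch.nn.functional as F

B, H, W, K = 2, 8, 8, 3
depth = torch.randn(B, 1, H, W)       # one source-feature channel
kernel = torch.randn(B, K * K, H, W)  # per-pixel K*K weights from the KGL

# Gather each pixel's KxK neighborhood: [B, K*K, H*W] -> [B, K*K, H, W]
patches = F.unfold(depth, kernel_size=K, padding=K // 2).view(B, K * K, H, W)

# Weighted sum over the window: receptive field K instead of 1
out = torch.sum(patches * kernel, dim=1, keepdim=True)  # [B, 1, H, W]
```

When the per-pixel kernel is spatially constant, this reduces to an ordinary F.conv2d, which is a handy sanity check for the unfold ordering.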

godspeed1989 avatar Jul 26 '20 14:07 godspeed1989