GuideNet

A naive GuidedConv implementation

Open AlanDecode opened this issue 5 years ago • 4 comments

I guess the source code for this repo will never be released, and since people are confused by the "Guided Conv Module" in the paper, I'm sharing my naive PyTorch implementation, based on the CSPN code.

Note: this implementation is based on my own understanding of the paper, so I'm not sure it's correct. Also, I got worse results with this module than with ordinary concat/add fusion.

Let me know if you have any questions or suggestions.

import torch
import torch.nn as nn


class _GuidedConv(nn.Module):
    def __init__(self):
        super(_GuidedConv, self).__init__()

        # Zero-pads that shift the feature map towards each of the
        # eight 3x3 neighbours; args are (left, right, top, bottom).
        self.pad_left_top = nn.ZeroPad2d((1, 0, 1, 0))
        self.pad_center_top = nn.ZeroPad2d((0, 0, 1, 0))
        self.pad_right_top = nn.ZeroPad2d((0, 1, 1, 0))
        self.pad_left_middle = nn.ZeroPad2d((1, 0, 0, 0))
        self.pad_right_middle = nn.ZeroPad2d((0, 1, 0, 0))
        self.pad_left_bottom = nn.ZeroPad2d((1, 0, 0, 1))
        self.pad_center_bottom = nn.ZeroPad2d((0, 0, 0, 1))
        self.pad_right_bottom = nn.ZeroPad2d((0, 1, 0, 1))

    def forward(self, x, cw: list, cc: list):
        """
        `x`: input feature maps with size `[B, C_in, H, W]`  
        `cw`: `C_in` channel-wise kernels, each with size `[B, 3*3, H, W]`  
        `cc`: `C_out` cross-channel 1*1 kernels, each with size `[B, C_in]`
        """

        # stage-1: weight each input channel with its 3x3 kernel in `cw`
        tmp = []
        for i in range(len(cw)):
            # [B, 9, H, W]: the channel plus its eight shifted copies
            feat = self._compose_feat(x[:, i, :, :].unsqueeze(1))
            feat = feat * cw[i]
            tmp.append(torch.sum(feat, dim=1, keepdim=True))
        tmp = torch.cat(tmp, dim=1)  # [B, C_in, H, W]

        # stage-2: mix the channels of `tmp` with the 1x1 kernels in `cc`
        out = []
        for i in range(len(cc)):
            # out-of-place unsqueeze so the caller's tensors are not mutated
            weight = cc[i].unsqueeze(-1).unsqueeze(-1)  # [B, C_in, 1, 1]
            out.append(torch.sum(tmp * weight, dim=1, keepdim=True))

        return torch.cat(out, dim=1)  # [B, C_out, H, W]

    def _compose_feat(self, feat: torch.FloatTensor):
        """Stack `feat` with its eight 3x3-neighbour shifts along dim 1."""
        H, W = feat.shape[2:]
        output = [feat]  # center

        # left-top
        output.append(self.pad_left_top(feat)[:, :, :H, :W])
        # center-top
        output.append(self.pad_center_top(feat)[:, :, :H, :])
        # right-top
        output.append(self.pad_right_top(feat)[:, :, :H, 1:])
        # left-middle
        output.append(self.pad_left_middle(feat)[:, :, :, :W])
        # right-middle
        output.append(self.pad_right_middle(feat)[:, :, :, 1:])
        # left-bottom
        output.append(self.pad_left_bottom(feat)[:, :, 1:, :W])
        # center-bottom
        output.append(self.pad_center_bottom(feat)[:, :, 1:, :])
        # right-bottom
        output.append(self.pad_right_bottom(feat)[:, :, 1:, 1:])

        return torch.cat(output, dim=1)  # [B, 3*3, H, W]
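For reference, the same two-stage computation can be vectorized with `F.unfold`, which gathers each pixel's 3x3 neighbourhood in one call instead of eight explicit `ZeroPad2d` shifts. This is only a sketch of my understanding, not the paper's code: the function name and the packed kernel shapes (`cw` and `cc` as single tensors rather than lists) are my own, and the neighbour ordering here is row-major (top-left to bottom-right) rather than centre-first as in the loop version above.

```python
import torch
import torch.nn.functional as F

def guided_conv(x, cw, cc):
    # x:  [B, C_in, H, W]       input features
    # cw: [B, C_in, 9, H, W]    per-pixel channel-wise 3x3 kernels
    # cc: [B, C_out, C_in]      per-image cross-channel 1x1 kernels
    B, C_in, H, W = x.shape
    # Gather 3x3 patches: [B, C_in*9, H*W] -> [B, C_in, 9, H, W]
    patches = F.unfold(x, kernel_size=3, padding=1).view(B, C_in, 9, H, W)
    # Stage 1: channel-wise spatial weighting, summed over the 3x3 window
    tmp = (patches * cw).sum(dim=2)                 # [B, C_in, H, W]
    # Stage 2: cross-channel 1x1 mixing
    out = torch.einsum('boi,bihw->bohw', cc, tmp)   # [B, C_out, H, W]
    return out

x = torch.randn(2, 4, 8, 8)
cw = torch.randn(2, 4, 9, 8, 8)
cc = torch.randn(2, 6, 4)
print(guided_conv(x, cw, cc).shape)  # torch.Size([2, 6, 8, 8])
```

A quick sanity check: with `cw` zero everywhere except weight 1 at the centre tap (index 4 in row-major order) and `cc` set to the identity, the module reduces to a no-op.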

AlanDecode, Jun 07 '20 10:06

Worse results as well. But CSPN works as a refinement module.

JUGGHM, Jun 07 '20 13:06

I think it is really hard to re-implement guided conv from the paper alone; too many details are missing. For example [screenshot from the paper omitted], we don't know the parameters of this "standard convolution layer", or whether it is followed by BN, an activation, etc.

godspeed1989, Jul 12 '20 15:07

Worse results as well. But CSPN works as a refinement module.

Is there an open source version of CSPN?

mrgransky, Jul 29 '20 09:07

Worse results as well. But CSPN works as a refinement module.

Is there an open source version of CSPN?

https://github.com/XinJCheng/CSPN

JUGGHM, Jul 29 '20 09:07