NCNet basenet PyTorch version
Hi Algolzw,
I'm trying to implement the NCNet architecture in PyTorch, but it doesn't work. Could you please take a look at my code?
import numpy as np
import torch
import torch.nn as nn


class NCNet(nn.Module):
    """A compact network structure for super-resolution.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        num_conv (int): Number of convolution layers in the body network. Default: 4.
        scale (int): Upsampling factor. Default: 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=4, scale=4):
        super(NCNet, self).__init__()
        self.num_in_ch = num_in_ch
        self.num_out_ch = num_out_ch
        self.num_feat = num_feat
        self.num_conv = num_conv
        self.scale = scale

        # Plain conv + ReLU body.
        self.body = nn.ModuleList()
        self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1))
        self.body.append(nn.ReLU())
        for _ in range(num_conv):
            self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1))
            self.body.append(nn.ReLU())
        self.body.append(nn.Conv2d(num_feat, num_out_ch * scale * scale, 3, 1, 1))
        self.body.append(nn.ReLU())
        self.body.append(nn.Conv2d(num_out_ch * scale * scale, num_out_ch * scale * scale, 3, 1, 1))
        self.upsampler = nn.PixelShuffle(scale)

        # Nearest-neighbor residual branch: repeat the 3x3 identity
        # scale*scale times (RGBRGB... pattern) and freeze it into a 1x1 conv.
        self.weight = torch.from_numpy(
            np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]] * scale * scale,
                     dtype=np.float32)).view(num_out_ch * scale * scale, num_in_ch, 1, 1)
        self.res = nn.Conv2d(num_in_ch, num_out_ch * scale * scale, 1, bias=False)
        with torch.no_grad():
            self.res.weight.copy_(self.weight)
        self._initialize_weights()

    def _initialize_weights(self) -> None:
        for module in self.body:
            if isinstance(module, nn.Conv2d):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = x
        for layer in self.body:
            out = layer(out)
        out = out + self.res(x)
        out = self.upsampler(out)
        return torch.clip(out, 0., 255.)
I think the reason is that the channel orderings of pixel shuffle in TensorFlow (tf.nn.depth_to_space) and PyTorch (nn.PixelShuffle) are different. In TensorFlow you can simply repeat the RGB channels scale*scale times (RGBRGBRGB...), but PyTorch groups the sub-pixel channels per color, so you need to repeat each channel separately, like RRR...GGG...BBB...
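For anyone who lands here, below is a minimal sketch of how the residual weight could be built in PyTorch's channel order (the helper name nn_residual_conv is mine, not from the repo). Since nn.PixelShuffle(r) reads input channel c*r*r + i*r + j as color plane c at sub-pixel offset (i, j), the r*r copies of each color must be contiguous, which torch.eye(...).repeat_interleave(r*r, dim=0) gives directly.

import torch
import torch.nn as nn

def nn_residual_conv(num_ch: int = 3, scale: int = 4) -> nn.Conv2d:
    # Frozen 1x1 conv whose output, after PixelShuffle(scale), equals
    # nearest-neighbor upsampling of the input.
    conv = nn.Conv2d(num_ch, num_ch * scale * scale, 1, bias=False)
    # weight[k, c] = 1 iff c == k // (scale*scale): each color fills a
    # contiguous block of scale*scale output channels (RRR...GGG...BBB...),
    # matching PyTorch's PixelShuffle ordering rather than TensorFlow's
    # RGBRGB... depth_to_space ordering.
    weight = torch.eye(num_ch).repeat_interleave(scale * scale, dim=0)
    with torch.no_grad():
        conv.weight.copy_(weight.view(num_ch * scale * scale, num_ch, 1, 1))
    conv.weight.requires_grad_(False)
    return conv

# Sanity check: the residual path reproduces nearest-neighbor upsampling.
x = torch.rand(1, 3, 8, 8)
up = nn.PixelShuffle(4)(nn_residual_conv(3, 4)(x))
assert torch.allclose(up, nn.functional.interpolate(x, scale_factor=4, mode='nearest'))

With this ordering, the frozen 1x1 conv plus PixelShuffle is exactly nearest-neighbor upsampling, so the body only needs to learn the residual.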
Hello, have you solved this problem yet? I'm looking forward to a PyTorch version.
+1
@kelisiya @wxslby A simple PyTorch implementation is now provided in the torch_code directory.