WeightNet icon indicating copy to clipboard operation
WeightNet copied to clipboard

In `WeightNet_DW`, is the line `x_w = x_w.reshape(-1, 1, 1, self.ksize, self.ksize)` wrong? Why are the input and output channel counts both 1?

Open kairenchen123 opened this issue 1 year ago • 0 comments

class WeightNet_DW(M.Module):
    r"""WeightNet applied to a depthwise convolution (grouping variant).

    Here we show a grouping manner when we apply WeightNet to a depthwise
    convolution. The grouped fc layer (``wn_fc2``) directly generates the
    convolutional kernel; it has fewer parameters while achieving comparable
    results. ``wn_fc2`` has ``M//G*inp`` inputs, ``inp`` groups and
    ``inp*ksize*ksize`` outputs.

    Args:
        inp: number of channels of the feature map ``x``; the depthwise
            convolution keeps this count (``inp`` in, ``inp`` out).
        ksize: spatial size of the square kernel; padding is ``ksize // 2``.
        stride: stride of the depthwise convolution.

    NOTE(review): assumes ``M`` is ``megengine.module`` and ``F`` is
    ``megengine.functional`` -- the file's imports are outside this view.
    """

    def __init__(self, inp, ksize, stride):
        super().__init__()

        # Only the ratio M // G is used below: it sets the width of the
        # hidden layer between wn_fc1 and wn_fc2 (M // G * inp channels).
        self.M = 2
        self.G = 2

        # Padding of ksize // 2 keeps the spatial size for odd ksize at stride 1.
        self.pad = ksize // 2
        # Channel count expected for x_gap (input of wn_fc1); presumably must
        # match the reduced/pooled feature the caller builds -- verify at call site.
        inp_gap = max(16, inp // 16)
        self.inp = inp
        self.ksize = ksize
        self.stride = stride

        # 1x1 convs acting as fully-connected layers on the pooled features.
        self.wn_fc1 = M.Conv2d(inp_gap, self.M // self.G * inp, 1, 1, 0, groups=1, bias=True)
        self.sigmoid = M.Sigmoid()
        # Grouped (groups=inp) fc that emits ksize*ksize kernel values per channel.
        self.wn_fc2 = M.Conv2d(
            self.M // self.G * inp, inp * ksize * ksize, 1, 1, 0, groups=inp, bias=False
        )

    def forward(self, x, x_gap):
        """Apply a per-sample depthwise convolution whose kernels come from ``x_gap``.

        Args:
            x: input feature map with ``self.inp`` channels (reshaped below
                as (B, inp, H, W)).
            x_gap: features after global average pooling (the original
                comment said "AGP", presumably meaning GAP). Must have
                ``max(16, inp // 16)`` channels to match ``wn_fc1``; assumed
                to be (B, inp_gap, 1, 1) -- TODO confirm. A spatial size other
                than 1x1 would yield extra kernel groups and break the conv.

        Returns:
            Tensor of shape (B, inp, H', W').
        """
        # Predict kernels: (B, inp_gap, 1, 1) -> (B, M//G*inp, 1, 1) -> (B, inp*k*k, 1, 1).
        x_w = self.wn_fc1(x_gap)
        x_w = self.sigmoid(x_w)
        x_w = self.wn_fc2(x_w)

        # Fold the batch into the channel axis so each sample is convolved with
        # its own kernels: (B, inp, H, W) -> (1, B*inp, H, W).
        x = x.reshape(1, -1, x.shape[2], x.shape[3])
        # Grouped-conv weight in MegEngine's 5-D layout
        # (groups, out_channels_per_group, in_channels_per_group, kh, kw).
        # With groups == B*inp, every group maps exactly one input channel to
        # one output channel -- a depthwise conv -- so both per-group channel
        # counts are 1 by design, not by mistake.
        x_w = x_w.reshape(-1, 1, 1, self.ksize, self.ksize)
        x = F.conv2d(x, weight=x_w, stride=self.stride, padding=self.pad, groups=x_w.shape[0])
        # Unfold the batch again: (1, B*inp, H', W') -> (B, inp, H', W').
        x = x.reshape(-1, self.inp, x.shape[2], x.shape[3])
        return x

kairenchen123 avatar May 21 '23 08:05 kairenchen123