
convolution layer with cardinality

Open OValery16 opened this issue 4 years ago • 4 comments

Dear author,

Thanks for this impressive piece of work. How do I implement a convolution layer with cardinality?

In universal.py we have:

def melt(self):
    # Only standard (groups == 1) and depthwise (groups == out_channels) convolutions
    # are handled; grouped convolutions with cardinality fall through to the assertion.
    if self.conv.groups == 1:
        groups = 1
    elif self.conv.groups == self.conv.out_channels:
        groups = int((self.out_mask != 0).sum())
    else:
        assert False

    replacer = nn.Conv2d(
        in_channels = int((self.in_mask != 0).sum()),
        out_channels = int((self.out_mask != 0).sum()),
        kernel_size = self.conv.kernel_size,
        stride = self.conv.stride,
        padding = self.conv.padding,
        dilation = self.conv.dilation,
        groups = groups,
        bias = (self.conv.bias is not None)
    ).to(self.conv.weight.device)

    with torch.no_grad():
        # Copy over only the surviving filters (and, for standard convolutions,
        # the surviving input channels).
        if self.conv.groups == 1:
            replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
        else:
            replacer.weight.set_(self.conv.weight[self.out_mask != 0])
        if self.conv.bias is not None:
            replacer.bias.set_(self.conv.bias[self.out_mask != 0])
    return replacer

If the convolution layer has cardinality (as in several modern models), we hit assert False. How can cardinality be implemented for convolution layers?

OValery16 avatar Jun 08 '20 13:06 OValery16
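
For context: PyTorch stores Conv2d weights with shape (out_channels, in_channels // groups, kH, kW), which is why melt() only handles the groups == 1 and depthwise cases. A minimal illustration of the three cases (not code from the repo; the layer sizes are made up):

import torch.nn as nn

# Weight shapes for the three cases melt() distinguishes.
standard  = nn.Conv2d(6, 6, kernel_size=3, groups=1)  # weight (6, 6, 3, 3): slice filters and input channels freely
depthwise = nn.Conv2d(6, 6, kernel_size=3, groups=6)  # weight (6, 1, 3, 3): dropping a filter also drops its group
grouped   = nn.Conv2d(6, 6, kernel_size=3, groups=3)  # weight (6, 2, 3, 3): cardinality > 1, falls into `assert False`

for name, conv in [("standard", standard), ("depthwise", depthwise), ("grouped", grouped)]:
    print(name, tuple(conv.weight.shape), "groups =", conv.groups)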

Removing a single filter from a group convolution in PyTorch will cause misalignment, but it is possible to remove an entire group by using the Group Pruning method proposed in our paper.

youzhonghui avatar Jun 09 '20 07:06 youzhonghui

Thanks a lot for your quick reply.

I understand that using the Group Pruning method allows us to synchronize the number of channels that are pruned within a group in order to avoid misalignment. However, I am still confused about how to make sure the final number of input channels in a given group is divisible by the number of groups (in the convolution).

Should I modify g.minimal_filter, e.g. g.minimal_filter = int(g.minimal_filter // conv_groups)?

OValery16 avatar Jun 09 '20 08:06 OValery16

In fact, I am a bit confused about how to remove an entire group.

OValery16 avatar Jun 09 '20 10:06 OValery16

If you remove a whole group of filters at a time, the number can be kept divisible. The troublesome part is maintaining the number of channels of the input feature map, since group convolution also divides the input. For example, given 6 filters divided into 3 groups, after removing the first group of convolutions there are 4 filters left and the cardinality should change to 2. The input channel count may still be 6, so the two corresponding feature maps need to be discarded. So the code to prune a network with cardinality may be very different from what we provided.

youzhonghui avatar Jun 10 '20 10:06 youzhonghui
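
A minimal sketch of the 6-filter / 3-group example above (an illustration, not code from the repo; group_to_drop, keep_out and keep_in are made-up names): dropping the whole first group leaves 4 filters and groups = 2, and the producing layer must also discard the two corresponding input channels.

import torch
import torch.nn as nn

# Drop an entire group from a grouped convolution: 6 filters, 3 groups -> 4 filters, 2 groups.
old = nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3, groups=3, bias=False)

group_to_drop = 0
out_per_group = old.out_channels // old.groups   # 2 filters per group
in_per_group = old.in_channels // old.groups     # 2 input channels per group

keep_out = [o for o in range(old.out_channels) if o // out_per_group != group_to_drop]
keep_in = [i for i in range(old.in_channels) if i // in_per_group != group_to_drop]

new = nn.Conv2d(
    in_channels=len(keep_in),
    out_channels=len(keep_out),
    kernel_size=old.kernel_size,
    stride=old.stride,
    padding=old.padding,
    dilation=old.dilation,
    groups=old.groups - 1,
    bias=False
)

with torch.no_grad():
    # Weights are (out_channels, in_channels // groups, kH, kW), so keeping whole
    # groups of output filters keeps each per-group input slice aligned.
    new.weight.set_(old.weight[keep_out].detach().clone())

x = torch.randn(1, 6, 8, 8)
y = new(x[:, keep_in])   # the layer producing x must also drop input channels 0 and 1
print(y.shape)           # torch.Size([1, 4, 6, 6])

In a full network, the out_mask of the preceding layer would have to be constrained to exactly keep_in, which is what keeps the divisibility requirement satisfied.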