
Pruning

Status: Open · Cydia2018 opened this issue 3 years ago · 9 comments

Could the author also release the pruning code? I still have some questions about the pruning part of the paper and the current code. Thanks.

Cydia2018 · Dec 11 '21 11:12

> Could the author also release the pruning code? I still have some questions about the pruning part of the paper and the current code. Thanks.

I hadn't originally planned to release the pruning code: on the one hand the code was fairly messy, and on the other I felt my implementation was a bit simplistic. RMNet pruning should be able to perform better, and I was hoping others could build something better on top of the models I released. But since there is demand for it, I have cleaned up the code: https://github.com/fxmeng/RMNet/blob/242f849c6e5e891646bbc90f89310268d183c310/train_pruning.py

fxmeng · Dec 24 '21 06:12

> > Could the author also release the pruning code? I still have some questions about the pruning part of the paper and the current code. Thanks.
>
> I hadn't originally planned to release the pruning code: on the one hand the code was fairly messy, and on the other I felt my implementation was a bit simplistic. RMNet pruning should be able to perform better, and I was hoping others could build something better on top of the models I released. But since there is demand for it, I have cleaned up the code: https://github.com/fxmeng/RMNet/blob/242f849c6e5e891646bbc90f89310268d183c310/train_pruning.py

Thank you for your work!

Cydia2018 · Dec 24 '21 06:12

Hi, I ran the pruning training code and found that it does not converge, so I made some changes following the idea of Network Slimming:

```python
def update_mask(self, sr, threshold):
    # L1 sparsity regularization on the grouped 1x1 mask convolutions,
    # following Network Slimming
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            if m.kernel_size == (1, 1) and m.groups != 1:
                m.weight.grad.data.add_(sr * torch.sign(m.weight.data))
                # m1 = m.weight.data.abs() > threshold
                # m.weight.grad.data *= m1
                # m.weight.data *= m1

def prune(self, use_bn=True, threshold=0.1):
    features = []
    in_mask = torch.ones(3) > 0  # all 3 input (RGB) channels are kept
    blocks = self.deploy()
    for i, m in enumerate(blocks):
        if isinstance(m, nn.BatchNorm2d):
            # keep only the channels whose BN scale exceeds the threshold
            mask = m.weight.data.abs().reshape(-1) > threshold
            ...
```
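For context, here is a minimal sketch of how such an `update_mask` hook would be called in the training loop; `model`, `train_loader`, and the optimizer settings are illustrative placeholders, not taken from the repository:

```python
import torch
import torch.nn as nn

# Illustrative only: `model` is a network defining update_mask() as above;
# train_loader and criterion are standard CIFAR-10 placeholders.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

for inputs, targets in train_loader:
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # adds sr * sign(w) to the gradients of the grouped 1x1 mask convs
    model.update_mask(sr=1e-4, threshold=2e-3)
    optimizer.step()
```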

Training ResNet-18 from scratch with sparsity (lr=0.1, sr=1e-4) gives the following results on CIFAR-10:

| | thresh | params | FLOPs | acc (%) |
| --- | --- | --- | --- | --- |
| original model | - | 15.38M | 803.75M | 94.96 |
| after pruning | 2e-3 | 4.06M | 397.52M | 94.81 |

Cydia2018 · Dec 29 '21 06:12

If it doesn't converge, that means sr and threshold are set too large; I suggest adjusting those two values and trying again. You are also encouraged to try more pruning schemes; just note that the newly added channels need to be pruned as well.
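One quick way to sanity-check a threshold before pruning is to see how many mask channels it would remove. A small sketch, assuming the grouped 1x1 mask convolutions from the snippet above (`mask_sparsity` is a hypothetical helper, not part of the repository):

```python
import torch.nn as nn

def mask_sparsity(model, threshold):
    # Fraction of per-channel mask weights that fall below the threshold,
    # i.e. the share of channels a threshold-based prune would remove.
    kept, total = 0, 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d) and m.kernel_size == (1, 1) and m.groups != 1:
            w = m.weight.data.abs().reshape(-1)
            kept += (w > threshold).sum().item()
            total += w.numel()
    return 1.0 - kept / total

# e.g. mask_sparsity(model, 2e-3) close to 1.0 suggests the threshold is too aggressive
```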

fxmeng · Dec 29 '21 06:12

> Hello! How should the parameters be set for training, pruning, and fine-tuning?

The sparsity factor is selected from 1e-4 to 1e-3, and the threshold is selected from 5e-4 to 5e-3.

fxmeng · Jan 05 '22 02:01

Thanks a lot! I still have some questions about regular training versus pruning training. Are the following parameter settings correct? Thanks!

Train the model: `python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune None --debn False --eval None`
Test the model: `python train_pruning.py --eval xxx/ckpt.pth`
Pruning training: `python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune xxx/ckpt.pth --debn False --eval None`

Serissa · Jan 05 '22 03:01

> Thanks a lot! I still have some questions about regular training versus pruning training. Are the following parameter settings correct? Thanks!
>
> Train the model: `python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune None --debn False --eval None`
> Test the model: `python train_pruning.py --eval xxx/ckpt.pth`
> Pruning training: `python train_pruning.py --lr 0.1 --sr 1e-4 --threshold 2e-3 --finetune xxx/ckpt.pth --debn False --eval None`

Yes, that works. The ranges I tuned over were 1e-4 ≤ sr ≤ 1e-3 and 5e-4 ≤ threshold ≤ 5e-3; tuning within those ranges should be fine. Also, when fine-tuning, reduce the learning rate to roughly 0.01.
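Concretely, with the flags from the commands above, the fine-tuning run would then look something like this (the checkpoint path is a placeholder):

`python train_pruning.py --lr 0.01 --sr 1e-4 --threshold 2e-3 --finetune xxx/ckpt.pth --debn False --eval None`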

fxmeng · Jan 05 '22 04:01

Hi, I converted the code below into RM form. Could you please help me check whether it is correct? Thanks!

```python
class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, e=0.5):  # ch_in, ch_out, expansion
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = nn.Conv2d(c1, c_, 1, 1)
        self.bn1 = nn.BatchNorm2d(c_)
        self.relu1 = nn.ReLU(inplace=True)
        self.cv2 = nn.Conv2d(c_, c2, 3, 1, 1)  # padding=1 keeps the spatial size for the residual add
        self.bn2 = nn.BatchNorm2d(c2)  # cv2 outputs c2 channels
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        return x + self.relu2(self.bn2(self.cv2(self.relu1(self.bn1(self.cv1(x))))))
```

Converted into RM form:

```python
class Bottleneck(nn.Module):
    # Standard bottleneck rewritten in RM form: the input channels are carried
    # alongside the residual branch, and grouped 1x1 "mask" convolutions are
    # inserted after each BN so that channels can be pruned later.
    def __init__(self, c1, c2, e=0.5):
        super(Bottleneck, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.in_planes1 = c1
        self.out_planes1 = c1 + c_
        self.in_planes2 = c1 + c_
        self.out_planes2 = c1 + c2
        self.out_planes = c2

        self.conv1 = nn.Conv2d(self.in_planes1, self.out_planes1 - c1, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.out_planes1 - c1)
        self.mask1 = nn.Conv2d(self.out_planes1 - c1, self.out_planes1 - c1, 1,
                               groups=self.out_planes1 - c1, bias=False)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(self.in_planes2 - c1, self.out_planes2 - c1, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(self.out_planes2 - c1)
        self.mask2 = nn.Conv2d(self.out_planes2 - c1, self.out_planes2 - c1, 1,
                               groups=self.out_planes2 - c1, bias=False)
        self.relu2 = nn.ReLU(inplace=False)

        # per-channel mask on the shortcut
        self.mask_res = nn.Sequential(
            nn.Conv2d(self.in_planes1, self.in_planes1, 1, groups=self.in_planes1, bias=False),
            nn.ReLU(inplace=False))

        # affine-free BNs that only track the running statistics needed by deploy()
        self.running1 = nn.BatchNorm2d(self.in_planes1, affine=False)
        self.running2 = nn.BatchNorm2d(self.out_planes, affine=False)

        nn.init.ones_(self.mask1.weight)
        nn.init.ones_(self.mask2.weight)
        nn.init.ones_(self.mask_res[0].weight)

    def forward(self, x):
        self.running1(x)  # record input statistics
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.mask1(out)
        out = self.relu1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.mask2(out)
        out = self.relu2(out)
        out = out + self.mask_res(x)
        self.running2(out)  # record output statistics
        return out

    def deploy(self, merge_bn=False):
        # idconv1: reserve in_planes1 identity channels in front of the conv1 channels
        idconv1 = nn.Conv2d(self.in_planes1, self.out_planes1, kernel_size=1, stride=1,
                            padding=0, bias=False).eval()
        idbn1 = nn.BatchNorm2d(self.out_planes1).eval()
        # dirac-initialize the identity channels; copy the input statistics into idbn1
        nn.init.dirac_(idconv1.weight.data[:self.in_planes1])
        bn_var_sqrt = torch.sqrt(self.running1.running_var + self.running1.eps)
        idbn1.weight.data[:self.in_planes1] = bn_var_sqrt
        idbn1.bias.data[:self.in_planes1] = self.running1.running_mean
        idbn1.running_mean.data[:self.in_planes1] = self.running1.running_mean
        idbn1.running_var.data[:self.in_planes1] = self.running1.running_var
        # copy conv1/bn1 into the remaining channels
        idconv1.weight.data[self.in_planes1:] = self.conv1.weight.data
        idbn1.weight.data[self.in_planes1:] = self.bn1.weight.data
        idbn1.bias.data[self.in_planes1:] = self.bn1.bias.data
        idbn1.running_mean.data[self.in_planes1:] = self.bn1.running_mean
        idbn1.running_var.data[self.in_planes1:] = self.bn1.running_var
        # fold mask_res (clamped to its positive part) and mask1 into idbn1
        mask1 = nn.Conv2d(self.out_planes1, self.out_planes1, 1, groups=self.out_planes1, bias=False)
        mask1.weight.data[:self.in_planes1] = self.mask_res[0].weight.data * (self.mask_res[0].weight.data > 0)
        mask1.weight.data[self.in_planes1:] = self.mask1.weight.data
        idbn1.weight.data *= mask1.weight.data.reshape(-1)
        idbn1.bias.data *= mask1.weight.data.reshape(-1)

        # idconv2: carry the identity channels alongside conv2
        idconv2 = nn.Conv2d(self.in_planes2, self.out_planes2, kernel_size=3, stride=1,
                            padding=1, bias=False).eval()
        idbn2 = nn.BatchNorm2d(self.out_planes2).eval()
        # zero everything first so the identity and conv2 paths stay separate,
        # then dirac-initialize the identity channels
        nn.init.zeros_(idconv2.weight.data)
        nn.init.dirac_(idconv2.weight.data[:self.in_planes1])
        bn_var_sqrt = torch.sqrt(self.running1.running_var + self.running1.eps)
        idbn2.weight.data[:self.in_planes1] = bn_var_sqrt
        idbn2.bias.data[:self.in_planes1] = self.running1.running_mean
        idbn2.running_mean.data[:self.in_planes1] = self.running1.running_mean
        idbn2.running_var.data[:self.in_planes1] = self.running1.running_var
        # copy conv2/bn2 into the remaining channels
        idconv2.weight.data[self.in_planes1:, self.in_planes1:, :, :] = self.conv2.weight.data
        idbn2.weight.data[self.in_planes1:] = self.bn2.weight.data
        idbn2.bias.data[self.in_planes1:] = self.bn2.bias.data
        idbn2.running_mean.data[self.in_planes1:] = self.bn2.running_mean
        idbn2.running_var.data[self.in_planes1:] = self.bn2.running_var
        # fold mask_res and mask2 into idbn2
        mask2 = nn.Conv2d(self.out_planes2, self.out_planes2, 1, groups=self.out_planes2, bias=False)
        mask2.weight.data[:self.in_planes1] = self.mask_res[0].weight.data * (self.mask_res[0].weight.data > 0)
        mask2.weight.data[self.in_planes1:] = self.mask2.weight.data
        idbn2.weight.data *= mask2.weight.data.reshape(-1)
        idbn2.bias.data *= mask2.weight.data.reshape(-1)

        # idconv3: add the identity channels back onto the residual output
        idconv3 = nn.Conv2d(self.out_planes2, self.out_planes, kernel_size=1, stride=1,
                            padding=0, bias=False).eval()
        idbn3 = nn.BatchNorm2d(self.out_planes).eval()
        nn.init.dirac_(idconv3.weight.data[:, :self.in_planes1])
        nn.init.dirac_(idconv3.weight.data[:, self.in_planes1:])
        bn_var_sqrt = torch.sqrt(self.running2.running_var + self.running2.eps)
        idbn3.weight.data = bn_var_sqrt
        idbn3.bias.data = self.running2.running_mean
        idbn3.running_mean.data = self.running2.running_mean
        idbn3.running_var.data = self.running2.running_var
        # fold mask_res into idbn3
        mask3 = nn.Conv2d(self.out_planes, self.out_planes, 1, groups=self.out_planes, bias=False)
        mask3.weight.data = self.mask_res[0].weight.data
        idbn3.weight.data *= mask3.weight.data.reshape(-1)
        idbn3.bias.data *= mask3.weight.data.reshape(-1)

        return [idconv1, idbn1, nn.ReLU(True), idconv2, idbn2, nn.ReLU(True), idconv3, idbn3]
```

You just need to check whether the outputs before and after the conversion are equal. Note that the comparison has to be done in eval() mode.
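For reference, a minimal before/after check could look like the following sketch; it assumes the `Bottleneck` posted above with `c1 == c2` (required for the residual add), and the channel and input sizes are arbitrary:

```python
import torch
import torch.nn as nn

# Sketch: verify that deploy() reproduces the training-time block in eval() mode.
block = Bottleneck(c1=64, c2=64).eval()
deployed = nn.Sequential(*block.deploy()).eval()

x = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    y_train = block(x)
    y_deploy = deployed(x)

# the maximum absolute difference should be near zero (up to float error)
print((y_train - y_deploy).abs().max().item())
```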

fxmeng · Jan 19 '22 12:01

Hi! I modified the code, but when comparing in eval() mode, the outputs before and after the conversion are not equal. I've been debugging for a whole day and can't find the problem. Could you take a look at the code and tell me where it's wrong?

Serissa · Jan 20 '22 08:01