pytorch-prunes icon indicating copy to clipboard operation
pytorch-prunes copied to clipboard

RuntimeError: running_mean should contain 57 elements not 64

Open cvJie opened this issue 5 years ago • 2 comments

hi, I success to pruned and finetune the cifar res_model, have successed to finish the pruned model by own data and model, but and now need to finetune the pruned model ,happened some error:

raceback (most recent call last): File "train_fashion.py", line 155, in log_dict_train=train(epoch,train_loader) File "train_fashion.py", line 117, in train log_dict_train, _ = trainer.train(epoch, train_loader) File "/fashiontrain/fashionprunes/lib/trains/base_trainer.py", line 103, in train return self.run_epoch('train', epoch, data_loader) File "/fashiontrain/fashionprunes/lib/trains/base_trainer.py", line 79, in run_epoch output, loss, loss_stats = model_with_loss(batch) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/fashiontrain/fashionprunes/lib/trains/base_trainer.py", line 19, in forward outputs = self.model(batch['input']) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/fashiontrain/fashionprunes/lib/models/networks/msra_resnet.py", line 372, in forward x = self.layer1(x) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py", line 92, in forward input = module(input) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/fashiontrain/fashionprunes/lib/models/networks/msra_resnet.py", line 137, in forward out = self.bn1(out) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 493, in call result = self.forward(*input, **kwargs) File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/batchnorm.py", line 83, in forward exponential_average_factor, self.eps) File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1697, in batch_norm training, momentum, eps, torch.backends.cudnn.enabled RuntimeError: running_mean should contain 57 elements not 64

can you give me some advices?? thank you

cvJie avatar Nov 13 '19 03:11 cvJie

When you prune a convolutional layer, you also need to prune the following batchnorm layer.

You can see an example of this here: https://github.com/BayesWatch/pytorch-prunes/blob/bc85a5c52865a2daf515ad4d3c26dcab88e3d941/models/wideresnet.py#L205

jack-willturner avatar Nov 13 '19 11:11 jack-willturner

When you prune a convolutional layer, you also need to prune the following batchnorm layer.

You can see an example of this here:

https://github.com/BayesWatch/pytorch-prunes/blob/bc85a5c52865a2daf515ad4d3c26dcab88e3d941/models/wideresnet.py#L205

------------------------------------------------------------------------------

Copyright (c) Microsoft

Licensed under the MIT License.

Written by Bin Xiao ([email protected])

Modified by Xingyi Zhou

------------------------------------------------------------------------------

from future import absolute_import from future import division from future import print_function import torch.nn.functional as F import os

import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo all = ['get_pose_net'] BN_MOMENTUM = 0.1

def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module): expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None):
    super(BasicBlock, self).__init__()
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=True)

    self.conv2 = conv3x3(planes, planes)
    self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.downsample = downsample
    self.stride = stride

def forward(self, x):
    residual = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        residual = self.downsample(x)

    out += residual
    out = self.relu(out)

    return out

class Identity(nn.Module): def init(self): super(Identity, self).init()

def forward(self, x):
    return x

class Zero(nn.Module): def init(self): super(Zero, self).init()

def forward(self, x):
    return x * 0

class ZeroMake(nn.Module): def init(self, channels, spatial): super(ZeroMake, self).init() self.spatial = spatial self.channels = channels

def forward(self, x):
    return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial],
                       dtype=x.dtype, layout=x.layout, device=x.device)

class MaskBlock(nn.Module): expansion = 1

def __init__(self, inplanes, planes, stride=1, downsample=None):

    super(MaskBlock, self).__init__()
    self.conv1 = conv3x3(inplanes, planes, stride)
    #self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)

    self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=True)

    self.conv2 = conv3x3(planes, planes)
    self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
    self.downsample = downsample
    self.stride = stride

    self.equalInOut = (inplanes == planes)

    self.activation = Identity()
    self.activation.register_backward_hook(self._fisher)
    self.register_buffer('mask', None)

    self.input_shape = None
    self.output_shape = None
    self.flops = None
    self.params = None
    self.in_channels = inplanes
    self.out_channels = planes
    self.got_shapes = False

    # Fisher method is called on backward passes
    self.running_fisher = 0

def forward(self, x):

    residual = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)


    if self.mask is not None:
        out = out * self.mask[None, :, None, None]
    else:
        self._create_mask(x, out)

    out = self.activation(out)
    self.act = out

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        residual = self.downsample(x)

    out += residual
    out = self.relu(out)


    return out

def _create_mask(self, x, out):

    self.mask = x.new_ones(out.shape[1])
    self.input_shape = x.size()
    self.output_shape = out.size()

def _fisher(self, _, notused2, grad_output):
    act = self.act.detach()
    grad = grad_output[0].detach()

    g_nk = (act * grad).sum(-1).sum(-1)
    del_k = g_nk.pow(2).mean(0).mul(0.5)
    self.running_fisher += del_k

def reset_fisher(self):
    self.running_fisher = 0 * self.running_fisher

def cost(self):

    in_channels = self.in_channels
    out_channels = self.out_channels
    middle_channels = int(self.mask.sum().item())

    conv1_size = self.conv1.weight.size()
    conv2_size = self.conv2.weight.size()

    # convs
    self.params = in_channels * middle_channels * conv1_size[2] * conv1_size[3] + middle_channels * out_channels * \
                  conv2_size[2] * conv2_size[3]

    # batchnorms, assuming running stats are absorbed
    self.params += 2 * in_channels + 2 * middle_channels

    # skip
    if not self.equalInOut:
        self.params += in_channels * out_channels
    else:
        self.params += 0

def compress_weights(self):

    middle_dim = int(self.mask.sum().item())
    print(middle_dim)

    if middle_dim is not 0:
        conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=3, stride=self.stride, padding=1, bias=False)
        conv1.weight = nn.Parameter(self.conv1.weight[self.mask == 1, :, :, :])

        # Batch norm 2 changes
        bn2 = nn.BatchNorm2d(middle_dim)
        bn2.weight = nn.Parameter(self.bn2.weight[self.mask == 1])
        bn2.bias = nn.Parameter(self.bn2.bias[self.mask == 1])
        bn2.running_mean = self.bn2.running_mean[self.mask == 1]
        bn2.running_var = self.bn2.running_var[self.mask == 1]

        conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        conv2.weight = nn.Parameter(self.conv2.weight[:, self.mask == 1, :, :])

    if middle_dim is 0:
        conv1 = Zero()
        bn2 = Zero()
        conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)

    self.conv1 = conv1
    self.conv2 = conv2
    self.bn2 = bn2

    if middle_dim is not 0:
        self.mask = torch.ones(middle_dim)
    else:
        self.mask = torch.ones(1)

class PoseResNet(nn.Module):

def __init__(self, block, layers, heads, head_conv, **kwargs):
    self.inplanes = 64
    self.deconv_with_bias = False
    self.heads = heads

    super(PoseResNet, self).__init__()
    self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                           bias=False)
    self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
    self.relu = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

    # used for deconv layers
    self.deconv_layers = self._make_deconv_layer(
        3,
        [256, 256, 256],
        [4, 4, 4],
    )
    for head in sorted(self.heads):
      num_output = self.heads[head]
      if head_conv > 0:
        fc = nn.Sequential(
            nn.Conv2d(256, head_conv,
              kernel_size=3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, num_output,
              kernel_size=1, stride=1, padding=0))
      else:
        fc = nn.Conv2d(
          in_channels=256,
          out_channels=num_output,
          kernel_size=1,
          stride=1,
          padding=0
      )
      self.__setattr__(head, fc)

    # Count params that don't exist in blocks (conv1, bn1, fc)
    self.fixed_params = len(self.conv1.weight.view(-1)) + len(self.bn1.weight) + len(self.bn1.bias)

def _make_layer(self, block, planes, blocks, stride=1):
    downsample = None
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            nn.Conv2d(self.inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
        )

    layers = []
    layers.append(block(self.inplanes, planes, stride, downsample))
    self.inplanes = planes * block.expansion
    for i in range(1, blocks):
        layers.append(block(self.inplanes, planes))

    return nn.Sequential(*layers)

def _get_deconv_cfg(self, deconv_kernel, index):
    if deconv_kernel == 4:
        padding = 1
        output_padding = 0
    elif deconv_kernel == 3:
        padding = 1
        output_padding = 1
    elif deconv_kernel == 2:
        padding = 0
        output_padding = 0

    return deconv_kernel, padding, output_padding

def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
    assert num_layers == len(num_filters), \
        'ERROR: num_deconv_layers is different len(num_deconv_filters)'
    assert num_layers == len(num_kernels), \
        'ERROR: num_deconv_layers is different len(num_deconv_filters)'

    layers = []
    for i in range(num_layers):
        kernel, padding, output_padding = \
            self._get_deconv_cfg(num_kernels[i], i)

        planes = num_filters[i]
        layers.append(
            nn.ConvTranspose2d(
                in_channels=self.inplanes,
                out_channels=planes,
                kernel_size=kernel,
                stride=2,
                padding=padding,
                output_padding=output_padding,
                bias=self.deconv_with_bias))

        layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
        layers.append(nn.ReLU(inplace=True))
        self.inplanes = planes

    return nn.Sequential(*layers)


def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.deconv_layers(x)

    #print("x:",x)
    ret = {}

    #print("selfheads:",self.heads){'hm': 13, 'wh': 2, 'reg': 2}
    for head in self.heads:
       ret[head] = self.__getattr__(head)(x)


    return [ret]

def init_weights(self, num_layers, pretrained=True):
    if pretrained:
        # print('=> init resnet deconv weights from normal distribution')
        for _, m in self.deconv_layers.named_modules():
            if isinstance(m, nn.ConvTranspose2d):
                # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                # print('=> init {}.bias as 0'.format(name))
                nn.init.normal_(m.weight, std=0.001)
                if self.deconv_with_bias:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # print('=> init {}.weight as 1'.format(name))
                # print('=> init {}.bias as 0'.format(name))
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # print('=> init final conv weights from normal distribution')

        for head in self.heads:
          final_layer = self.__getattr__(head)
          #print("*************final_layer:",final_layer)
          for i, m in enumerate(final_layer.modules()):
              #print("m:",m)
              if isinstance(m, nn.Conv2d):
                  # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                  # print('=> init {}.weight as normal(0, 0.001)'.format(name))
                  # print('=> init {}.bias as 0'.format(name))
                  if m.weight.shape[0] == self.heads[head]:
                      if 'hm' in head:
                          nn.init.constant_(m.bias, -2.19)
                      else:
                          nn.init.normal_(m.weight, std=0.001)
                          nn.init.constant_(m.bias, 0)

    else:
        print('=> imagenet pretrained model dose not exist')
        print('=> please download it first')
        raise ValueError('imagenet pretrained model does not exist')

def get_pose_net(num_layers, heads, head_conv,mask=False): #block_class, layers = resnet_spec[num_layers] print("heads:",heads)

if mask==1:
    model = PoseResNet(MaskBlock, [2, 2, 2, 2], heads, head_conv=head_conv)
    # for name, parameters in model.named_parameters():
    #     print(name, '            \t:', parameters.size())
else:
    model = PoseResNet(BasicBlock, [2,2,2,2], heads, head_conv=head_conv)
    #print("model:", model)

model.init_weights(num_layers= num_layers,pretrained=True)
return model

this is my code ,can you give me some advices?? thanks

cvJie avatar Nov 14 '19 02:11 cvJie