pytorch-prunes
pytorch-prunes copied to clipboard
RuntimeError: running_mean should contain 57 elements not 64
Hi, I successfully pruned and fine-tuned the CIFAR ResNet model, and I have also managed to prune my own model on my own data. However, when I now try to fine-tune that pruned model, I get the following error:
Traceback (most recent call last):
File "train_fashion.py", line 155, in
Can you give me some advice? Thank you.
When you prune a convolutional layer, you also need to prune the following batchnorm layer.
You can see an example of this here: https://github.com/BayesWatch/pytorch-prunes/blob/bc85a5c52865a2daf515ad4d3c26dcab88e3d941/models/wideresnet.py#L205
When you prune a convolutional layer, you also need to prune the following batchnorm layer.
You can see an example of this here:
https://github.com/BayesWatch/pytorch-prunes/blob/bc85a5c52865a2daf515ad4d3c26dcab88e3d941/models/wideresnet.py#L205
------------------------------------------------------------------------------
Copyright (c) Microsoft
Licensed under the MIT License.
Written by Bin Xiao ([email protected])
Modified by Xingyi Zhou
------------------------------------------------------------------------------
from future import absolute_import from future import division from future import print_function import torch.nn.functional as F import os
import torch import torch.nn as nn import torch.utils.model_zoo as model_zoo all = ['get_pose_net'] BN_MOMENTUM = 0.1
def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 convolution with padding 1 and no bias term.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=3``, ``padding=1`` and
        ``bias=False``.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block input (optionally transformed by ``downsample``) is added to the
    output of the second batch norm before the final ReLU.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (default 1).
        downsample: optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is used.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First conv carries the stride; the second always uses stride 1.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Shortcut branch: raw input unless a downsample module was supplied.
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        main += shortcut
        return self.relu(main)
class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    Being a real ``nn.Module`` lets callers attach hooks to it (the masked
    block in this file registers a backward hook on an instance).
    """

    def __init__(self):
        # Must be the ``__init__`` dunder: a method merely named ``init`` is
        # never invoked as the constructor.
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` itself (no copy is made)."""
        return x
class Zero(nn.Module):
    """Module that maps its input to zeros of the same shape.

    The result is computed as ``x * 0``, so it keeps the shape, dtype and
    device of ``x``.
    """

    def __init__(self):
        # Must be the ``__init__`` dunder: a method merely named ``init`` is
        # never invoked as the constructor.
        super(Zero, self).__init__()

    def forward(self, x):
        """Return ``x`` multiplied by zero."""
        return x * 0
class ZeroMake(nn.Module):
    """Produce an all-zero 4-D tensor shaped like a strided conv output.

    Args:
        channels: size of dimension 1 of the produced tensor.
        spatial: integer divisor applied (floor division) to dimensions 2 and
            3 of the input to obtain the output's spatial size.
    """

    def __init__(self, channels, spatial):
        # Must be the ``__init__`` dunder; with a method merely named ``init``
        # the keyword arguments would be rejected by ``nn.Module.__init__``
        # and ``channels`` / ``spatial`` would never be stored.
        super(ZeroMake, self).__init__()
        self.spatial = spatial
        self.channels = channels

    def forward(self, x):
        """Return zeros of shape ``[N, channels, H // spatial, W // spatial]``.

        ``N``, ``H`` and ``W`` are dimensions 0, 2 and 3 of ``x``; dtype,
        layout and device are copied from ``x``.
        """
        batch, _, height, width = x.size()[:4]
        out_shape = [batch, self.channels,
                     height // self.spatial, width // self.spatial]
        return torch.zeros(out_shape, dtype=x.dtype, layout=x.layout,
                           device=x.device)
class MaskBlock(nn.Module):
    """Residual block whose middle channels can be masked and then pruned.

    Layer order is ``conv1 -> bn1 -> relu -> mask -> conv2 -> bn2``, followed
    by the residual addition and a final ReLU. The "middle" channels are the
    outputs of ``conv1`` (and therefore the channels of ``bn1`` and the inputs
    of ``conv2``). A per-channel ``mask`` buffer multiplies the middle
    activations; ``compress_weights`` physically removes the channels whose
    mask entry is not 1.

    A backward hook on an ``Identity`` module placed after the mask
    accumulates a per-channel score in ``running_fisher``.

    Args:
        inplanes: number of input channels.
        planes: number of channels produced by both convolutions.
        stride: stride of the first convolution (default 1).
        downsample: optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is used.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(MaskBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride
        self.equalInOut = (inplanes == planes)

        # Identity layer exists only so a backward hook can observe the
        # gradient flowing into the (masked) middle activation.
        self.activation = Identity()
        self.activation.register_backward_hook(self._fisher)
        # The mask is created lazily on the first forward pass.
        self.register_buffer('mask', None)

        self.input_shape = None
        self.output_shape = None
        self.flops = None
        self.params = None
        self.in_channels = inplanes
        self.out_channels = planes
        self.got_shapes = False

        # Fisher method is called on backward passes
        self.running_fisher = 0

    def forward(self, x):
        """Run the block; the middle activation is multiplied by ``mask``.

        On the first call (``mask`` is ``None``) an all-ones mask is created
        and the input / middle-activation shapes are recorded instead.
        """
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        if self.mask is not None:
            # Broadcast the per-channel mask over batch and spatial dims.
            out = out * self.mask[None, :, None, None]
        else:
            self._create_mask(x, out)

        out = self.activation(out)
        # Kept so that the backward hook can pair it with the gradient.
        self.act = out

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

    def _create_mask(self, x, out):
        """Create an all-ones mask (one entry per middle channel) and record
        the shapes of the block input ``x`` and middle activation ``out``."""
        self.mask = x.new_ones(out.shape[1])
        self.input_shape = x.size()
        self.output_shape = out.size()

    def _fisher(self, _, notused2, grad_output):
        """Backward hook: accumulate a per-channel score in ``running_fisher``.

        For each channel the product of the stored activation and the incoming
        gradient is summed over the last two dimensions, squared, averaged
        over dimension 0 and halved; the result is added to the running total.
        """
        act = self.act.detach()
        grad = grad_output[0].detach()

        g_nk = (act * grad).sum(-1).sum(-1)
        del_k = g_nk.pow(2).mean(0).mul(0.5)
        self.running_fisher += del_k

    def reset_fisher(self):
        """Reset the accumulated score to zero (keeping its type and shape)."""
        self.running_fisher = 0 * self.running_fisher

    def cost(self):
        """Compute the block's parameter count for the current mask.

        The result is stored in ``self.params``; nothing is returned. Requires
        that ``mask`` exists (i.e. at least one forward pass has run).
        """
        in_channels = self.in_channels
        out_channels = self.out_channels
        middle_channels = int(self.mask.sum().item())

        conv1_size = self.conv1.weight.size()
        conv2_size = self.conv2.weight.size()

        # convs
        self.params = (
            in_channels * middle_channels * conv1_size[2] * conv1_size[3]
            + middle_channels * out_channels * conv2_size[2] * conv2_size[3]
        )

        # batchnorms, assuming running stats are absorbed. In this block bn1
        # normalises the middle channels (it follows conv1) and bn2 the output
        # channels (it follows conv2), so those are the widths to count.
        self.params += 2 * middle_channels + 2 * out_channels

        # skip
        if not self.equalInOut:
            self.params += in_channels * out_channels

    def compress_weights(self):
        """Physically remove the masked-out middle channels.

        ``conv1`` loses the corresponding output filters, ``bn1`` (which
        directly follows ``conv1``) loses the same channels, and ``conv2``
        loses the corresponding input channels. ``bn2`` is untouched because
        it operates on ``conv2``'s output channels, whose count does not
        change. Slicing ``bn2`` instead of ``bn1`` leaves ``bn1`` wider than
        ``conv1``'s output and fails at run time with
        "running_mean should contain X elements not Y".

        If every middle channel is masked, ``conv1`` / ``bn1`` are replaced by
        ``Zero`` and ``conv2`` by ``ZeroMake`` so the branch outputs zeros of
        the right shape. Afterwards the mask is reset to all ones. Calling
        this again on a fully pruned block is not supported, since the
        placeholder modules have no weights to slice.
        """
        keep = self.mask == 1
        middle_dim = int(self.mask.sum().item())
        print(middle_dim)

        # Integer comparison by value; ``is`` / ``is not`` on an int literal
        # tests identity and is flagged by Python 3.8+.
        if middle_dim != 0:
            conv1 = nn.Conv2d(self.in_channels, middle_dim, kernel_size=3,
                              stride=self.stride, padding=1, bias=False)
            conv1.weight = nn.Parameter(self.conv1.weight[keep, :, :, :])

            # Batch norm 1 follows conv1, so it shrinks with conv1's outputs.
            # Carry over eps / momentum from the layer being replaced.
            bn1 = nn.BatchNorm2d(middle_dim, eps=self.bn1.eps,
                                 momentum=self.bn1.momentum)
            bn1.weight = nn.Parameter(self.bn1.weight[keep])
            bn1.bias = nn.Parameter(self.bn1.bias[keep])
            bn1.running_mean = self.bn1.running_mean[keep]
            bn1.running_var = self.bn1.running_var[keep]

            conv2 = nn.Conv2d(middle_dim, self.out_channels, kernel_size=3,
                              stride=1, padding=1, bias=False)
            conv2.weight = nn.Parameter(self.conv2.weight[:, keep, :, :])

            # new_ones keeps the mask on the same device / dtype as before.
            new_mask = self.mask.new_ones(middle_dim)
        else:
            conv1 = Zero()
            bn1 = Zero()
            conv2 = ZeroMake(channels=self.out_channels, spatial=self.stride)
            # A single-element mask broadcasts over any channel count.
            new_mask = self.mask.new_ones(1)

        self.conv1 = conv1
        self.bn1 = bn1
        self.conv2 = conv2
        self.mask = new_mask
class PoseResNet(nn.Module):
    """ResNet-style backbone with transposed-conv upsampling and output heads.

    Structure, in forward order: a 7x7 stride-2 convolution with batch norm,
    ReLU and max pooling; four residual stages of 64, 128, 256 and 512 planes;
    three stride-2 transposed convolutions with 256 filters each; and one
    convolutional head per entry of ``heads``.

    Args:
        block: residual block class; must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each stage.
        heads: mapping from head name to its number of output channels.
        head_conv: if > 0, each head is a 3x3 conv with ``head_conv`` channels,
            a ReLU and a 1x1 conv; otherwise each head is a single 1x1 conv.
        **kwargs: accepted and ignored.
    """

    def __init__(self, block, layers, heads, head_conv, **kwargs):
        super(PoseResNet, self).__init__()
        self.inplanes = 64
        self.deconv_with_bias = False
        self.heads = heads

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # used for deconv layers
        self.deconv_layers = self._make_deconv_layer(
            3,
            [256, 256, 256],
            [4, 4, 4],
        )

        # Heads are created in sorted-name order and registered as
        # sub-modules under their own name.
        for head in sorted(self.heads):
            num_output = self.heads[head]
            if head_conv > 0:
                fc = nn.Sequential(
                    nn.Conv2d(256, head_conv,
                              kernel_size=3, padding=1, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_conv, num_output,
                              kernel_size=1, stride=1, padding=0))
            else:
                fc = nn.Conv2d(
                    in_channels=256,
                    out_channels=num_output,
                    kernel_size=1,
                    stride=1,
                    padding=0
                )
            setattr(self, head, fc)

        # Count params that don't exist in blocks (conv1, bn1, fc)
        self.fixed_params = (len(self.conv1.weight.view(-1))
                             + len(self.bn1.weight)
                             + len(self.bn1.bias))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a 1x1 conv + batch norm shortcut. Updates
        ``self.inplanes`` to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion,
                               momentum=BN_MOMENTUM),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def _get_deconv_cfg(self, deconv_kernel, index):
        """Return ``(kernel, padding, output_padding)`` for a deconv layer.

        Supported kernel sizes are 4, 3 and 2; ``index`` is unused.

        Raises:
            ValueError: if ``deconv_kernel`` is not one of the supported sizes
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if deconv_kernel == 4:
            padding = 1
            output_padding = 0
        elif deconv_kernel == 3:
            padding = 1
            output_padding = 1
        elif deconv_kernel == 2:
            padding = 0
            output_padding = 0
        else:
            raise ValueError(
                'unsupported deconv kernel size: {}'.format(deconv_kernel))

        return deconv_kernel, padding, output_padding

    def _make_deconv_layer(self, num_layers, num_filters, num_kernels):
        """Build ``num_layers`` stride-2 transposed-conv + BN + ReLU groups.

        ``num_filters[i]`` and ``num_kernels[i]`` give the output channels and
        kernel size of layer ``i``. Updates ``self.inplanes`` as it goes.
        """
        assert num_layers == len(num_filters), \
            'ERROR: num_deconv_layers is different len(num_deconv_filters)'
        # This check concerns the kernels list, so the message names it.
        assert num_layers == len(num_kernels), \
            'ERROR: num_deconv_layers is different len(num_deconv_kernels)'

        layers = []
        for i in range(num_layers):
            kernel, padding, output_padding = \
                self._get_deconv_cfg(num_kernels[i], i)

            planes = num_filters[i]
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=self.inplanes,
                    out_channels=planes,
                    kernel_size=kernel,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=self.deconv_with_bias))
            layers.append(nn.BatchNorm2d(planes, momentum=BN_MOMENTUM))
            layers.append(nn.ReLU(inplace=True))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Returns:
            A one-element list holding a dict that maps each head name to that
            head's output tensor.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.deconv_layers(x)

        ret = {}
        for head in self.heads:
            ret[head] = getattr(self, head)(x)
        return [ret]

    def init_weights(self, num_layers, pretrained=True):
        """Initialise the deconv layers and the output heads.

        Deconv weights are drawn from a normal distribution (std 0.001) and
        their batch norms set to weight 1 / bias 0. In each head, every conv
        whose output channel count equals the head's configured output count
        is initialised: heads whose name contains ``'hm'`` get bias -2.19,
        all others get normal(std=0.001) weights and zero bias.

        Args:
            num_layers: unused by this implementation.
            pretrained: must be truthy.

        Raises:
            ValueError: if ``pretrained`` is falsy.
        """
        if not pretrained:
            print('=> imagenet pretrained model dose not exist')
            print('=> please download it first')
            raise ValueError('imagenet pretrained model does not exist')

        for m in self.deconv_layers.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
                if self.deconv_with_bias:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        for head in self.heads:
            final_layer = getattr(self, head)
            for m in final_layer.modules():
                if not isinstance(m, nn.Conv2d):
                    continue
                # Only touch the conv that produces the head's final output.
                if m.weight.shape[0] == self.heads[head]:
                    if 'hm' in head:
                        nn.init.constant_(m.bias, -2.19)
                    else:
                        nn.init.normal_(m.weight, std=0.001)
                        nn.init.constant_(m.bias, 0)
def get_pose_net(num_layers, heads, head_conv, mask=False):
    """Construct and initialise a ``PoseResNet`` with a [2, 2, 2, 2] layout.

    Args:
        num_layers: only forwarded to ``init_weights``; the stage layout is
            fixed at two blocks per stage regardless of its value.
        heads: mapping from head name to number of output channels.
        head_conv: channel count of the intermediate head convolution.
        mask: when equal to 1 (``True``) the network is built from
            ``MaskBlock``; otherwise from ``BasicBlock``.

    Returns:
        The initialised model.
    """
    # NOTE(review): the collapsed source line ends with
    # `print("heads:",heads)` after a '#'; it is treated as commented out
    # here — confirm against the original file.
    block_cls = MaskBlock if mask == 1 else BasicBlock
    stage_depths = [2, 2, 2, 2]

    model = PoseResNet(block_cls, stage_depths, heads, head_conv=head_conv)
    model.init_weights(num_layers=num_layers, pretrained=True)
    return model
This is my code. Can you give me some advice? Thanks.