torch-model-compression icon indicating copy to clipboard operation
torch-model-compression copied to clipboard

剪枝分割网络报错-bisenetv2

Open GeneralJing opened this issue 2 years ago • 17 comments

File "examples/torchpruner/prune_by_class_bisenetv2.py", line 39, in <module>
    model, context = torchpruner.set_cut(model, result)
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/model_pruner.py", line 71, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/pruners.py", line 188, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/pruners.py", line 60, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/prune_function.py", line 42, in set_cut_tensor
IndexError: index 78 is out of bounds for dimension 0 with size 78

GeneralJing avatar Oct 09 '21 09:10 GeneralJing

有解决的消息了吗?是跟分组卷积有关吗?

GeneralJing avatar Oct 12 '21 09:10 GeneralJing

呃这两天有别的事情,还没看多久,晚上我再自己测测

gdh1995 avatar Oct 12 '21 09:10 gdh1995

嗯 好的 辛苦大佬,有消息及时回复下,自己手动剪,感觉比较步骤比较繁琐

GeneralJing avatar Oct 12 '21 09:10 GeneralJing

能贴你的 prune_by_class_bisenetv2.py 吗?我创建 lib/models/bisenetv2.py#BiSeNetV2(19)graph.build_graph 没跑过去。bisnet 是刚从 https://github.com/CoinCheung/BiSeNet 找的。

gdh1995 avatar Oct 12 '21 10:10 gdh1995

import sys

sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np
from bisenetv2 import BiSeNetV2

#以下代码示例了对每一个BN层去除其weight系数绝对值前20%小的层

#加载模型
model = torchvision.models.vgg11_bn()
print('model-origin:', model)
#jzy
#model = BiSeNetV2(n_classes=9, aux_mode='pred')
model = BiSeNetV2(n_classes=9)
model.load_state_dict(torch.load('/home/zxz/torch-model-compression/examples/torchpruner/model_final.pth', map_location='cuda'), strict=False)
print('model-origin:', model)

# 创建ONNXGraph对象,绑定需要被剪枝的模型
graph = torchpruner.ONNXGraph(model)
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(8, 3, 320, 640),))

# 遍历所有的Module
for key in graph.modules:
    module = graph.modules[key]
    # 如果该module对应了BN层
    if isinstance(module.nn_object, torch.nn.BatchNorm2d):
        # 获取该对象
        nn_object = module.nn_object
        # 排序,取前20%小的权重值对应的index
        weight = nn_object.weight.detach().cpu().numpy()
        index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
        print('index:', index)
        result = module.cut_analysis("weight", index=index, dim=0)
        model, context = torchpruner.set_cut(model, result)

# 新的model即为剪枝后的模型
print('model-pruned:', model)

这个代码。

GeneralJing avatar Oct 12 '21 10:10 GeneralJing

comment框是markdown格式的,我的注释在这很奇怪

GeneralJing avatar Oct 12 '21 10:10 GeneralJing

markdown的多行代码块语法是 三个 ` 连着表示开头和结尾,比如

# // 开头
# ``` [ + 空格 + 语言名]
# 具体内容
# // 结尾
# ``` 

gdh1995 avatar Oct 12 '21 10:10 gdh1995

嗯嗯 有空看看mk语法,大佬看问题 嘿嘿

GeneralJing avatar Oct 12 '21 10:10 GeneralJing

你用的 bisenetv2.py 是哪个文件?BiSeNet-master/lib/models/bisenetv2.py 还是 BiSeNet-master/old/bisenetv2/bisenetv2.py ?

另外我用官方仓库的模型(old/README.md的百度网盘文件 model_final.pth)好像加载不了,大概是参数不一致……厚颜来要你的原始模型了,能发的话,网盘或者发到 [email protected] 都行;不方便的话,我就再看看。我现在因为懒得自己在coco上训 练,缺少实际可用的参数张量,卡在前边某一个分组卷积的 cut_analysis 步骤了Orz

gdh1995 avatar Oct 12 '21 14:10 gdh1995

用的是BiSeNet-master/lib/models/bisenetv2.py这个,模型的话在公司,可以给你发一个训练几个epoch的版本,因为是其他人训练的,也不方便给最新的。得明天给你了。

GeneralJing avatar Oct 12 '21 14:10 GeneralJing

嗯谢谢!

gdh1995 avatar Oct 12 '21 14:10 gdh1995

Traceback (most recent call last): File "G:/py/torch-model-compression-main/examples/torchpruner/prune_and_recovery.py", line 16, in graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),)) File "G:\py\torch-model-compression-main\torchpruner\graph.py", line 458, in build_graph graph, params_dict, torch_out = torch.onnx.utils._model_to_graph( TypeError: _model_to_graph() got an unexpected keyword argument '_retain_param_name'

好几个都是 加载模型 时出现上述错误。我是刚入门的小白,请问这个怎么处理,程序跑不通,不知道怎么改这个。 model = torchvision.models.resnet50()

jzy-hxf avatar Nov 18 '21 01:11 jzy-hxf

@jzy-hxf 抱歉之前没注意消息。具体到你这个问题,是torch版本比较新(1.11还是多少以上)造成的。你把这个项目里出现的 _retain_param_name 都去掉就行了,不影响结果。

gdh1995 avatar Dec 07 '21 02:12 gdh1995

@GeneralJing 抱歉我好久没注意这个。我确认了几遍代码,应该是 examples/torchpruner/prune_by_class.py 写的有问题,每次执行 torchpruner.set_cutgraph 的部分信息会过时,所以需要重新创建 graphgraph.modules 是稳定的,可以预先算好 keys

for key in list(graph.modules):
    # ...
    model, context = torchpruner.set_cut(model, result)
    graph = torchpruner.ONNXGraph(model)  # 本行可以省略
    graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

我一会去改一下示例代码。

gdh1995 avatar Dec 07 '21 07:12 gdh1995

“下个周末看一下”,重新定义了下个周末哈哈哈哈,好的,多谢。

GeneralJing avatar Dec 07 '21 07:12 GeneralJing

参考你的例子,对BiSeNetv2模型剪枝,报如下错误,这个怎么解决,能看一下吗

index: [34 53 31 46 12 49 36 60 18 68 39 57 67 13 14]
key: self.segment.S3.0.dwconv1.1
Traceback (most recent call last):
  File "model_pruning.py", line 33, in <module>
    result = module.cut_analysis("weight", index=index, dim=0)
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 333, in cut_analysis
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 246, in cut_analysis
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/graph.py", line 269, in cut_analysis_with_mask
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/operator/onnx_operator.py", line 489, in analysis
  File "/usr/local/lib/python3.6/dist-packages/torchpruner-0.1.0-py3.6.egg/torchpruner/mask_utils.py", line 276, in indexs
RuntimeError: All the data is masked

Durobert avatar May 11 '22 10:05 Durobert

我的代码如下

import imp
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(BASE_DIR, '..'))
import torch
import torchpruner
import numpy as np
from models.bisenetv2 import BiSeNetV2


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载模型
model = BiSeNetV2(n_classes=4)
checkpoint = torch.load('../results/05-05_01-59/checkpoint_best.pkl', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

graph = torchpruner.ONNXGraph(model)
graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))

for key in list(graph.modules):
    module = graph.modules[key]
    if isinstance(module.nn_object, torch.nn.BatchNorm2d):
        nn_object = module.nn_object
        weight = nn_object.weight.detach().cpu().numpy()
        index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.2)]
        print('index:', index)
        print('key:', key)
        result = module.cut_analysis("weight", index=index, dim=0)
        model, context = torchpruner.set_cut(model, result)
        graph = torchpruner.ONNXGraph(model)
        graph.build_graph(inputs=(torch.zeros(1, 3, 640, 512),))

print('model-pruned:', model)
torch.save(model, '../results/05-05_01-59/checkpoint_best_pruning.pkl')

Durobert avatar May 11 '22 10:05 Durobert