nni
nni copied to clipboard
ModelSpeedup() error
Environment:
- NNI version: 3.0
- Python version: 3.11 (per the site-packages path in the traceback)
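I have applied FPGM pruning to an object detector with an FPN and skip connections. ModelSpeedup() doesn't work with my model's architecture: the error is raised while tracing the model with concrete_trace(), whose graph is then passed to ModelSpeedup().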
Here is the code I used:
# Reproduction script: FPGM-prune the Conv2d layers of a detector built by
# build_model(), trace it with concrete_trace, then apply ModelSpeedup.
import torch
from models import build_model
from data.config import cfg
from nni.compression.pruning import FPGMPruner
from nni.common.concrete_trace_utils import concrete_trace
from nni.compression.speedup import ModelSpeedup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model("test", cfg.NUM_CLASSES, width_mult=0.0625).to(device)

# Prune 20% of every Conv2d layer, except the head convolutions listed below
# (loc.0-5 / conf.0-5), which are excluded from pruning.
config_list = [{
    'sparsity_per_layer': 0.2,
    'op_types': ['Conv2d'],
}, {
    'exclude': True,
    'op_names': [
        'loc.0', 'loc.1', 'loc.2', 'loc.3', 'loc.4', 'loc.5',
        'conf.0', 'conf.1', 'conf.2', 'conf.3', 'conf.4', 'conf.5',
    ],
}]

# Example input handed to ModelSpeedup.
dummy_input = torch.rand(2, 3, 640, 640).to(device)

pruner = FPGMPruner(model, config_list)
_, masks = pruner.compress()
pruner.unwrap_model()

# ModelSpeedup is given an explicitly traced graph module, so trace the model
# first. The key 'x' is assumed to be the name of the model's forward()
# argument -- TODO confirm against build_model's forward signature.
dummy_input_for_trace = torch.rand([1, 3, 640, 640]).to(device)
# NOTE(review): the IndexError in the traceback is raised inside this call
# (concrete_tracer.py, while inspecting bytecode opcodes), i.e. before
# ModelSpeedup runs. The environment is Python 3.11; confirm whether this NNI
# 3.0 tracer supports 3.11 bytecode (e.g. retry on Python 3.10).
graph_module = concrete_trace(model, {'x': dummy_input_for_trace})

model = ModelSpeedup(model, dummy_input, masks, graph_module=graph_module).speedup_model()
This is the error I get:
IndexError Traceback (most recent call last)
Cell In[8], line 32
29 pruner.unwrap_model()
31 dummy_input_for_trace = torch.rand([1, 3, 640, 640]).to(device)
---> 32 graph_module = concrete_trace(model, {'x': dummy_input_for_trace})
33 #print(masks)
34 model = ModelSpeedup(model, dummy_input, masks, graph_module=graph_module).speedup_model()
File ~/anaconda3/envs/gpu/lib/python3.11/site-packages/nni/common/concrete_trace_utils/concrete_tracer.py:1606, in concrete_trace(root, concrete_args, use_operator_patch, operator_patch_backlist, forward_function_name, check_args, autowrap_leaf_function, autowrap_leaf_class, leaf_module, fake_middle_class, dce, cpu_offload, trace_twice)
1603 is_training = root.training
1604 root.eval()
-> 1606 graph = tracer.trace(root,
1607 autowrap_leaf_function = autowrap_leaf_function,
1608 autowrap_leaf_class = autowrap_leaf_class,
1609 leaf_module = leaf_module,
1610 fake_middle_class = fake_middle_class,
1611 concrete_args = concrete_args,
1612 use_operator_patch = use_operator_patch,
1613 operator_patch_backlist = operator_patch_backlist,
1614 forward_function_name = forward_function_name,
1615 )
1617 if trace_twice:
1618 graph_check = tracer.trace(root,
1619 autowrap_leaf_function = autowrap_leaf_function,
...
165 if insts[cur].opcode in self.jump_opcodes or (
166 insts[cur].opcode in self.jump_before_opcodes and insts[cur + 1].opcode in self.jump_opcodes):
167 # in executing branch condition
IndexError: list index out of range