Torch-Pruning
Pruned Yolov8 model not loading?
@VainF I have trained a custom YOLOv8 model. After training, I successfully pruned the model with the following code:
import torch
import torch_pruning as tp
from ultralytics.nn.modules import Detect

# model is the ultralytics YOLO object holding the trained weights;
# replace_c2f_with_c2f_v2 is defined in the Torch-Pruning YOLOv8 example script
for name, param in model.model.named_parameters():
    param.requires_grad = True
replace_c2f_with_c2f_v2(model.model)
model.model.eval()

example_inputs = torch.randn(1, 3, 800, 800).to(model.device)
imp = tp.importance.MagnitudeImportance(p=2)  # L2 norm pruning

ignored_layers = []
unwrapped_parameters = []
modules_list = list(model.model.modules())
for i, m in enumerate(modules_list):
    if isinstance(m, (Detect,)):
        ignored_layers.append(m)  # keep the detection head unpruned

iterative_steps = 1  # values > 1 prune progressively; 1 prunes in a single step
pruner = tp.pruner.MagnitudePruner(
    model.model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # remove 50% of channels
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)
base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)
pruner.step()
pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs)
print("Before Pruning: MACs=%f G, #Params=%f G" % (base_macs / 1e9, base_nparams / 1e9))
print("After Pruning: MACs=%f G, #Params=%f G" % (pruned_macs / 1e9, pruned_nparams / 1e9))
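Conceptually, MagnitudeImportance(p=2) combined with ch_sparsity=0.5 scores each output channel by the L2 norm of its weights and removes the lowest-scoring half. The pure-Python sketch below illustrates only that ranking step on one toy layer. It is not Torch-Pruning's implementation: the library additionally groups coupled layers through its dependency graph and prunes them together, and the helper names here (l2_channel_scores, channels_to_prune) are hypothetical.

```python
import math


def l2_channel_scores(weight):
    """weight: list of output channels, each a flat list of floats (toy stand-in for a conv weight)."""
    return [math.sqrt(sum(w * w for w in channel)) for channel in weight]


def channels_to_prune(weight, sparsity):
    """Return the indices of the int(sparsity * n) output channels with the smallest L2 norm."""
    scores = l2_channel_scores(weight)
    n_prune = int(sparsity * len(scores))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i])  # ascending importance
    return sorted(ranked[:n_prune])


# Toy 4-channel layer: channels 1 and 3 have the smallest L2 norms.
toy_weight = [
    [0.9, -0.8, 0.7],
    [0.1, 0.0, -0.1],
    [0.5, 0.5, -0.5],
    [0.05, 0.02, 0.0],
]
print(channels_to_prune(toy_weight, 0.5))  # [1, 3]
```

With sparsity 0.5 and four channels, two channels are dropped; the surviving channels keep their original weights, which is why fine-tuning afterwards is commonly used to recover accuracy.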
I then saved the pruned model as shown below. Is this the correct way to save a pruned model?
torch.save(pruner.model, "prune.pt")
When I load the saved model, I get the following error:
pruned_model = YOLO("prune.pt")
AttributeError                            Traceback (most recent call last)
Cell In[88], line 1
----> 1 pruned_model = YOLO("prune.pt")
File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/engine/model.py:94, in Model.__init__(self, model, task)
     92         self._new(model, task)
     93     else:
---> 94         self._load(model, task)
File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/engine/model.py:140, in Model._load(self, weights, task)
    138 suffix = Path(weights).suffix
    139 if suffix == '.pt':
--> 140     self.model, self.ckpt = attempt_load_one_weight(weights)
    141     self.task = self.model.args['task']
    142     self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/nn/tasks.py:609, in attempt_load_one_weight(weight, device, inplace, fuse)
    607 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    608     """Loads a single model weights."""
--> 609     ckpt, weight = torch_safe_load(weight)  # load ckpt
    610     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
    611     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model
File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/nn/tasks.py:548, in torch_safe_load(weight)
...
   1413             pass
   1414     mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1415     return super().find_class(mod_name, name)
AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>
Also, I'd like to ask: is it necessary to train the pruned model again?
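The traceback also hints at what YOLO("prune.pt") expects to find in the file: tasks.py:610-611 call ckpt.get('train_args', {}) and ckpt.get('ema') or ckpt['model'], i.e. a dict-like checkpoint. torch.save(pruner.model, "prune.pt") writes the bare model object instead, and an nn.Module has no .get() method, so the load would likely fail at that line even once the class lookup succeeds. The stdlib sketch below imitates only those two lines, using stand-ins: DEFAULT_CFG_DICT is a tiny placeholder for the real Ultralytics config, SimpleNamespace replaces the model, and load_like_traceback/try_load are hypothetical helpers.

```python
from types import SimpleNamespace

# Tiny placeholder; the real Ultralytics DEFAULT_CFG_DICT holds many more keys.
DEFAULT_CFG_DICT = {"task": "detect"}


def load_like_traceback(ckpt):
    """Imitates only the two lookups at tasks.py:610-611 in the traceback."""
    args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))}
    model = ckpt.get("ema") or ckpt["model"]  # prefer the EMA weights when present
    return model, args


def try_load(ckpt):
    """Return the (model, args) pair, or the AttributeError that was raised."""
    try:
        return load_like_traceback(ckpt)
    except AttributeError as exc:
        return exc


pruned = SimpleNamespace(name="pruned-yolov8")  # stand-in for pruner.model

# 1) What torch.save(pruner.model, ...) writes: the bare model object, which has no .get()
print("bare model ->", try_load(pruned))

# 2) A dict with a "model" entry passes both lookups
model, args = load_like_traceback({"model": pruned, "train_args": {"imgsz": 800}})
print("dict checkpoint ->", model.name, args)
```

If this reading is right, a dict such as {"model": pruner.model} gets past these two lines, but model.py:141 then reads self.model.args['task'], so the saved model object probably also needs an args attribute. Treat this as a hypothesis to verify against your Ultralytics version.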
Have you solved this problem? I have run into the same issue.
@luoshiyong Not yet. How did you save your model?
Hi,
I got the YOLOv8 example demo working, and I'm able to load and run the pruned model with:
from ultralytics.nn.tasks import attempt_load_one_weight
model, _ = attempt_load_one_weight(weights)  # weights: path to the saved pruned .pt file
For that, I need to define (or import) the C2f_v2 class, since it's not part of Ultralytics YOLOv8:
import torch
import torch.nn as nn
from ultralytics.nn.modules import Bottleneck, Conv

class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
Hope it will help!
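In the C2f_v2 class above, torch.cat(y, 1) concatenates 2 + n tensors of self.c channels each (cv0, cv1, and one output per Bottleneck, whose width stays self.c because it is built as Bottleneck(self.c, self.c, ...)). That is where the input width (2 + n) * self.c of cv2 comes from. The arithmetic-only sketch below traces that bookkeeping; the function name c2f_v2_channels and the layer sizes in the example are hypothetical.

```python
def c2f_v2_channels(c1, c2, n=1, e=0.5):
    """Channel bookkeeping for C2f_v2 (pure arithmetic, no tensors)."""
    c = int(c2 * e)          # hidden channels, as in self.c = int(c2 * e)
    branches = [c, c]        # cv0(x) and cv1(x) each produce c channels
    branches += [c] * n      # each Bottleneck maps c -> c channels
    concat = sum(branches)   # channels entering cv2 after torch.cat(y, 1)
    return {"hidden": c, "cv2_in": concat, "out": c2}


print(c2f_v2_channels(c1=128, c2=256, n=3))
```

For c2=256, e=0.5 and n=3 this gives 128 hidden channels and 5 * 128 = 640 channels into cv2, matching Conv((2 + n) * self.c, c2, 1). Note that c1 only affects cv0/cv1's input side, not these counts.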
@ashray21 hey, have you solved this?
Put the C2f_v2 code in module.py or block.py (which file depends on your YOLOv8 version, since older and newer versions differ), or put it in the main file where the loading function is called.
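Why defining the class fixes the load: torch.save(pruner.model, ...) pickles the whole module, and pickle records each custom class only by its module path and name (for a class defined in a notebook or script, that module is __main__). When loading, the unpickler must resolve that class again, which is what the find_class frame at the end of the traceback is doing. The stdlib sketch below reproduces the mechanism without torch; fake_main and C2fV2Stub are hypothetical stand-ins for __main__ and C2f_v2.

```python
import pickle
import sys
import types

# Hypothetical stand-in for the notebook's __main__, where C2f_v2 lived when the model was saved.
fake_main = types.ModuleType("fake_main")
sys.modules["fake_main"] = fake_main


class C2fV2Stub:
    """Stand-in for a custom layer such as C2f_v2."""

    def __init__(self, c):
        self.c = c


# Make pickle record the class as "fake_main.C2fV2Stub", analogous to "__main__.C2f_v2".
C2fV2Stub.__module__ = "fake_main"
fake_main.C2fV2Stub = C2fV2Stub

payload = pickle.dumps(C2fV2Stub(64))  # only the class *reference* is stored, not its code


def load_without_class(data):
    """Simulate loading in a process where the class was never defined."""
    saved = fake_main.C2fV2Stub
    del fake_main.C2fV2Stub
    try:
        pickle.loads(data)
    except AttributeError as exc:
        return exc  # pickle cannot find the class attribute on the module
    finally:
        fake_main.C2fV2Stub = saved  # "define C2f_v2 before loading"
    return None


error = load_without_class(payload)
restored = pickle.loads(payload)  # succeeds once the class is resolvable again
print(type(error).__name__, restored.c)
```

Because only the module path and class name are stored, the class must be importable under the same path at load time that it had at save time.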