
Pruned Yolov8 model not loading?

Open ashray21 opened this issue 3 months ago • 3 comments

@VainF I have trained a custom YOLOv8 model. After training, I pruned it successfully with the following code:

import torch
import torch_pruning as tp
from ultralytics import YOLO
from ultralytics.nn.modules import Detect

# `model` is the trained YOLO object, e.g. model = YOLO("best.pt").
# `replace_c2f_with_c2f_v2` comes from Torch-Pruning's yolov8 pruning example and
# rewrites C2f blocks into the prunable C2f_v2 variant.
for name, param in model.model.named_parameters():
    param.requires_grad = True

replace_c2f_with_c2f_v2(model.model)

model.model.eval()
example_inputs = torch.randn(1, 3, 800, 800).to(model.device)
imp = tp.importance.MagnitudeImportance(p=2)  # L2-norm pruning

ignored_layers = []
unwrapped_parameters = []

# Keep the detection head unpruned.
for m in model.model.modules():
    if isinstance(m, Detect):
        ignored_layers.append(m)

iterative_steps = 1  # one-shot pruning; increase for progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model.model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5,  # remove 50% of the channels
    ignored_layers=ignored_layers,
    unwrapped_parameters=unwrapped_parameters,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model.model, example_inputs)
pruner.step()

pruned_macs, pruned_nparams = tp.utils.count_ops_and_params(pruner.model, example_inputs)
print("Before Pruning: MACs=%f G, #Params=%f G" % (base_macs / 1e9, base_nparams / 1e9))
print("After Pruning: MACs=%f G, #Params=%f G" % (pruned_macs / 1e9, pruned_nparams / 1e9))

After that, I save the pruned model. Is this the correct way to save it?

torch.save(pruner.model, "prune.pt")

After saving the model, I load it and get the following error:

 pruned_model = YOLO("prune.pt")

AttributeError                            Traceback (most recent call last)
Cell In[88], line 1
----> 1 pruned_model = YOLO("prune.pt")

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/engine/model.py:94, in Model.__init__(self, model, task)
     92     self._new(model, task)
     93 else:
---> 94     self._load(model, task)

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/engine/model.py:140, in Model._load(self, weights, task)
    138 suffix = Path(weights).suffix
    139 if suffix == '.pt':
--> 140     self.model, self.ckpt = attempt_load_one_weight(weights)
    141     self.task = self.model.args['task']
    142     self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/nn/tasks.py:609, in attempt_load_one_weight(weight, device, inplace, fuse)
    607 def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
    608     """Loads a single model weights."""
--> 609     ckpt, weight = torch_safe_load(weight)  # load ckpt
    610     args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))}  # combine model and default args, preferring model args
    611     model = (ckpt.get('ema') or ckpt['model']).to(device).float()  # FP32 model

File ~/.virtualenvs/ashray_dev/lib/python3.10/site-packages/ultralytics/nn/tasks.py:548, in torch_safe_load(weight)
...
   1413         pass
   1414 mod_name = load_module_mapping.get(mod_name, mod_name)
-> 1415 return super().find_class(mod_name, name)

AttributeError: Can't get attribute '__main__' on <module 'builtins' (built-in)>

Also, I need to ask: is it necessary to train the pruned model again (fine-tune it)?

ashray21 avatar Apr 01 '24 06:04 ashray21

Have you solved this problem? I ran into the same issue.

luoshiyong avatar Apr 08 '24 03:04 luoshiyong

Have you solved this problem? I ran into the same issue.

@luoshiyong Not yet. How did you save your model?
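
One thing I am considering trying (untested) is saving an ultralytics-style checkpoint dict instead of the bare module, since attempt_load_one_weight in the traceback above reads the 'train_args' and 'ema'/'model' keys. The key names come from that traceback; everything else here is an assumption:

# Untested sketch: save in the dict layout that attempt_load_one_weight expects.
# C2f_v2 still has to be importable in the loading process, otherwise the same
# unpickling error occurs.
ckpt = {
    "model": pruner.model,                            # the pruned DetectionModel
    "train_args": model.ckpt.get("train_args", {}),   # assumption: reuse the original training args
}
torch.save(ckpt, "prune.pt")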

ashray21 avatar Apr 08 '24 06:04 ashray21

Hi,

I managed to use the YOLOv8 example demo. I can then load and run the pruned model using:

from ultralytics.nn.tasks import attempt_load_one_weight

model, _ = attempt_load_one_weight(weights)

But for that I need to define the C2f_v2 class, as it's not part of ultralytics YOLOv8:

import torch
import torch.nn as nn
from ultralytics.nn.modules import Conv, Bottleneck

class C2f_v2(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv0 = Conv(c1, self.c, 1, 1)
        self.cv1 = Conv(c1, self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)
        )

    def forward(self, x):
        # y = list(self.cv1(x).chunk(2, 1))
        y = [self.cv0(x), self.cv1(x)]
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))
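
With C2f_v2 defined (or importable) in the loading script, the unpickling succeeds. A rough sketch of how the loading side can look (untested; the file name is a placeholder, and it assumes the checkpoint was saved as a dict with a 'model' entry as sketched earlier in the thread):

import torch
from ultralytics.nn.tasks import attempt_load_one_weight

# C2f_v2 must be defined or imported *before* this call, otherwise the unpickler
# cannot resolve the class reference and raises the AttributeError shown above.
pruned_model, ckpt = attempt_load_one_weight("prune.pt", device="cpu")
pruned_model.eval()

with torch.no_grad():
    out = pruned_model(torch.randn(1, 3, 800, 800))  # sanity check that the pruned graph runs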

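As for the retraining question above: removing 50% of the channels usually costs accuracy, so a fine-tuning pass on your dataset is normally needed afterwards. As far as I can tell, the stock YOLO.train() rebuilds the network from its yaml config, which is why the Torch-Pruning YOLOv8 example uses its own training/saving routine for the pruned model. A quick, untested way to see how much accuracy was lost (the dataset yaml is a placeholder) is:

# Untested sketch: evaluate the pruned model still in memory to quantify the accuracy
# drop before deciding how long to fine-tune. `model` is the YOLO object whose weights
# were pruned in place earlier in the thread.
metrics = model.val(data="my_dataset.yaml", imgsz=800)
print(metrics.box.map)  # mAP50-95 of the pruned model
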
Hope it helps!

CloudRider-pixel avatar Apr 16 '24 09:04 CloudRider-pixel