nebuly
How to save and load an optimized YOLOv5 model
Hi, I have followed the Speedster YOLOv5 notebook. I would just like to know how to save the optimized model and then load it again for inference.
Thanks
Hello @ozayr, once the model is optimized you just need to save its compiled version, i.e. the content of the variable model_optimized in the notebook (alternatively, by the end of the notebook it should also be stored in model.model.model.core). You can save it by running model_optimized.save("path_to_save_dir").
Loading the model is slightly trickier:
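- Instantiate a YOLOv5 model (use the same variant you optimized, e.g. yolov5s, so that its last layer matches the compiled core):
import torch
yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, force_reload=True)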
- Load the optimized model by running:
from nebullvm.operations.inference_learners.base import LearnerMetadata
optimized_model = LearnerMetadata.read("model_save_path").load_model("model_save_path")
Replace "model_save_path" with the path to the directory where you previously saved the optimized model.
- Create a class that wraps the compiled model into a PyTorch model. You can copy-paste the class shown in the notebook:
class OptimizedYolo(torch.nn.Module):
    def __init__(self, optimized_core, head_layer):
        super().__init__()
        self.core = optimized_core  # compiled (optimized) part of the network
        self.head = head_layer  # original, uncompiled last layer

    def forward(self, x, *args, **kwargs):
        x = list(self.core(x))  # the compiled core returns a tuple
        return self.head(x)
- Extract the last layer (the detection head) of the original YOLO model, since it wasn't compiled:
last_layer = list(yolo_model.model.model.model.children())[-1]
- Instantiate an OptimizedYolo object and replace the core model in the original yolo_model with the optimized one:
yolo_optimized = OptimizedYolo(optimized_model, last_layer)
yolo_model.model.model = yolo_optimized
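The list(...) call in OptimizedYolo.forward matters because the compiled core returns a tuple, while YOLOv5's Detect head (in the YOLOv5 source we have seen) overwrites its input in place with x[i] = ..., which a tuple does not allow. The torch-free sketch below mimics that behaviour with a hypothetical fake_detect_head operating on plain floats; it only illustrates the data flow and is not the real layer.

```python
from typing import List, Optional, Sequence


def fake_detect_head(features: List[float]) -> List[float]:
    """Hypothetical stand-in for YOLOv5's Detect head.

    Like the real head, it overwrites each entry of its input in place,
    which only works on a mutable sequence such as a list.
    """
    for i in range(len(features)):
        features[i] = features[i] * 2  # in-place item assignment
    return features


def wrapped_forward(core_output: Sequence[float]) -> List[float]:
    """Mirrors OptimizedYolo.forward: convert the core's tuple to a list first."""
    return fake_detect_head(list(core_output))


core_output = (1.0, 2.0, 3.0)  # the compiled core returns a tuple

print(wrapped_forward(core_output))  # [2.0, 4.0, 6.0]

tuple_error: Optional[TypeError] = None
try:
    fake_detect_head(core_output)  # pass the tuple directly, as if list() were missing
except TypeError as err:
    tuple_error = err
    print("tuple input fails:", err)
```

Without the conversion the head would fail with a TypeError, which is why the wrapper turns the core's output into a list before calling the head.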
Tell me if something is unclear. Feel free to add a section to the Yolo notebook about saving and loading the model.
In addition, if you think it could be useful for the rest of the community, you can add a utils.py module in speedster/api implementing two functions for saving and loading the YOLO model. Basically, these two functions would just wrap the instructions above.
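A minimal sketch of what such a utils.py could contain, wrapping the steps above. The names save_yolo_optimized and load_yolo_optimized are hypothetical (nothing in this thread says such helpers exist yet). torch and nebullvm are imported inside load_yolo_optimized so the module can be imported without them, and the code assumes the attribute paths used above (yolo_model.model.model.model) match your YOLOv5 version.

```python
"""Hypothetical speedster/api/utils.py: save/load helpers for an optimized YOLOv5.

Sketch only: it wraps the manual steps from this thread and assumes the
YOLOv5 attribute layout used in the Speedster notebook.
"""
from typing import Any


def save_yolo_optimized(model_optimized: Any, save_dir: str) -> None:
    """Save the compiled core (the notebook's model_optimized) to save_dir."""
    # Only the compiled core is saved; the uncompiled head is rebuilt on load.
    model_optimized.save(save_dir)


def load_yolo_optimized(save_dir: str, variant: str = "yolov5s") -> Any:
    """Rebuild a YOLOv5 model whose core is the optimized model in save_dir.

    variant must be the same YOLOv5 variant that was optimized.
    """
    # Deferred imports: torch and nebullvm are only needed when loading.
    import torch
    from nebullvm.operations.inference_learners.base import LearnerMetadata

    class OptimizedYolo(torch.nn.Module):
        """Compiled core followed by the original (uncompiled) head layer."""

        def __init__(self, optimized_core, head_layer):
            super().__init__()
            self.core = optimized_core
            self.head = head_layer

        def forward(self, x, *args, **kwargs):
            x = list(self.core(x))  # the compiled core returns a tuple
            return self.head(x)

    # 1. Instantiate the original YOLOv5 model.
    yolo_model = torch.hub.load("ultralytics/yolov5", variant, pretrained=True)
    # 2. Load the optimized core from disk.
    optimized_core = LearnerMetadata.read(save_dir).load_model(save_dir)
    # 3. Take the last (uncompiled) layer of the original model.
    last_layer = list(yolo_model.model.model.model.children())[-1]
    # 4. Swap the original core for the optimized one.
    yolo_model.model.model = OptimizedYolo(optimized_core, last_layer)
    return yolo_model
```

Usage would then be save_yolo_optimized(model_optimized, "model_save_path") at the end of the optimization session, and yolo_model = load_yolo_optimized("model_save_path") in the inference session; the returned object should then be usable for inference like yolo_model in the notebook.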