super-gradients
Store model type and number of classes as metadata in the '.pt' file
🚀 Feature Request
Getting a custom model to work currently requires manual steps that could be fully automated. The model type and the number of classes must be passed explicitly every time the weights are loaded:
```python
from super_gradients.training import models

self.model = models.get(
    <type_of_yolo_model>,
    num_classes=<the-number-of-classes-in-your-model>,
    checkpoint_path=<path-to-your-weights.pt>,
)
```
Proposed Solution (Optional)
You could save the model and its associated metadata like this:
```python
import torch

torch.save({
    'state_dict': model.state_dict(),
    'metadata': {'model_type': 'yolo_nas_s', 'classes_count': 80},
}, '/pth/to/your/model.pt')
```
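A minimal sketch of the save side that runs with plain PyTorch. It assumes a tiny `nn.Linear` as a stand-in for a YOLO-NAS model and an in-memory buffer in place of the file path. `save_with_metadata` is a hypothetical helper name, and the `'state_dict'`/`'metadata'` keys follow the proposal above rather than any existing super-gradients format.

```python
import io

import torch
from torch import nn


def save_with_metadata(model: nn.Module, model_type: str, classes_count: int, f) -> None:
    # Hypothetical helper: bundle the weights with what is needed to rebuild the model.
    torch.save(
        {
            "state_dict": model.state_dict(),
            "metadata": {"model_type": model_type, "classes_count": classes_count},
        },
        f,
    )


# Stand-in for a real detector: a single linear layer with one output per class.
model = nn.Linear(16, 80)

buffer = io.BytesIO()  # in-memory "file", so the sketch needs no disk access
save_with_metadata(model, "yolo_nas_s", 80, buffer)

buffer.seek(0)
checkpoint = torch.load(buffer)
print(checkpoint["metadata"])  # {'model_type': 'yolo_nas_s', 'classes_count': 80}
print(sorted(checkpoint["state_dict"]))  # ['bias', 'weight']
```

Because the metadata is a plain dict of strings and integers, it travels in the same file as the weights and can be read back without knowing anything about the model in advance.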
Then you could load it like this:
```python
import torch
from super_gradients.training import models

checkpoint = torch.load('/pth/to/your/model.pt')
metadata = checkpoint['metadata']
self.model = models.get(
    metadata['model_type'],
    num_classes=metadata['classes_count'],
    checkpoint_path='/pth/to/your/model.pt',
)
```
This would remove any manual work from loading custom models. I get quite a lot of requests for this here.
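A sketch of how the loading side could also cope with checkpoints saved before such a change, assuming the same `'state_dict'`/`'metadata'` layout as in the proposal. `REGISTRY`, `tiny_linear` and `load_custom_model` are hypothetical names that stand in for `models.get()` so the example runs with plain PyTorch. This is not super-gradients' actual API.

```python
import io
from typing import Callable, Dict, Optional

import torch
from torch import nn

# Hypothetical registry standing in for models.get(): maps a model-type
# name to a constructor that takes the number of classes.
REGISTRY: Dict[str, Callable[[int], nn.Module]] = {
    "tiny_linear": lambda num_classes: nn.Linear(16, num_classes),
}


def load_custom_model(f, model_type: Optional[str] = None,
                      num_classes: Optional[int] = None) -> nn.Module:
    """Rebuild a model from a checkpoint, reading its type and class count
    from the embedded metadata unless they are passed explicitly."""
    checkpoint = torch.load(f)
    metadata = checkpoint.get("metadata", {})
    if model_type is None:
        model_type = metadata.get("model_type")
    if num_classes is None:
        num_classes = metadata.get("classes_count")
    if model_type is None or num_classes is None:
        # Older checkpoints without metadata still need the manual arguments.
        raise ValueError("checkpoint has no metadata; pass model_type and num_classes")
    model = REGISTRY[model_type](num_classes)
    model.load_state_dict(checkpoint["state_dict"])
    return model


# Usage: save a checkpoint with metadata, then load it with no extra arguments.
original = nn.Linear(16, 5)
buffer = io.BytesIO()
torch.save(
    {
        "state_dict": original.state_dict(),
        "metadata": {"model_type": "tiny_linear", "classes_count": 5},
    },
    buffer,
)
buffer.seek(0)
restored = load_custom_model(buffer)
print(restored.out_features)  # 5
```

Letting explicit arguments take precedence over the metadata keeps the current explicit call pattern working for old checkpoints, while new checkpoints make those arguments optional.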