
Store model type and number of classes as metadata in the '.pt' file

Open · mikel-brostrom opened this issue 1 year ago · 3 comments

🚀 Feature Request

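Loading a custom model currently requires manual work that could be fully automated. The model type and the number of classes have to be passed by hand every time, even though both are fixed once the model has been trained: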

from super_gradients.training import models

model = models.get(
    <type_of_yolo_model>,
    num_classes=<the-number-of-classes-in-your-model>,
    checkpoint_path=<path-to-your-weights.pt>,
)

Proposed Solution (Optional)

You could save the model and its associated metadata like this:

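import torch

# super-gradients training checkpoints keep the weights under the 'net' key;
# reusing that key keeps the file loadable through checkpoint_path
torch.save({
    'net': model.state_dict(),
    'metadata': {'model_type': 'yolo_nas_s', 'classes_count': 80},
}, '/path/to/your/model.pt')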
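A minimal sketch of the saving side, in plain Python around `torch.save`. The helper `build_checkpoint` is hypothetical (not part of super-gradients): it bundles the weights and the metadata into one dict and rejects obviously wrong values before anything is written. It stores the weights under `net`, on the assumption that `checkpoint_path` loading looks for that key, as super-gradients' own training checkpoints do.

```python
from typing import Any, Dict, Mapping


def build_checkpoint(
    state_dict: Mapping[str, Any], model_type: str, classes_count: int
) -> Dict[str, Any]:
    """Bundle weights and metadata into one dict, ready to pass to torch.save.

    Hypothetical helper: the 'net' / 'metadata' layout follows this proposal
    and is not an existing super-gradients API.
    """
    if not isinstance(model_type, str) or not model_type:
        raise ValueError("model_type must be a non-empty string, e.g. 'yolo_nas_s'")
    # bool is a subclass of int, so reject it explicitly
    if isinstance(classes_count, bool) or not isinstance(classes_count, int) or classes_count <= 0:
        raise ValueError("classes_count must be a positive integer")
    return {
        "net": state_dict,  # weights, stored as-is
        "metadata": {"model_type": model_type, "classes_count": classes_count},
    }
```

Usage would be along the lines of `torch.save(build_checkpoint(model.state_dict(), 'yolo_nas_s', 80), '/path/to/your/model.pt')`. Validating at save time means a bad value is caught when the file is written rather than when someone else tries to load it.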

Then you could load it like this:

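import torch
from super_gradients.training import models

# map_location='cpu' lets a checkpoint saved on a GPU be read on a CPU-only machine
checkpoint = torch.load('/path/to/your/model.pt', map_location='cpu')
metadata = checkpoint['metadata']

model = models.get(
    metadata['model_type'],
    num_classes=metadata['classes_count'],
    checkpoint_path='/path/to/your/model.pt',
)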
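On the loading side, checkpoints written before such a change would not carry any metadata. A minimal sketch, assuming the `metadata` layout proposed in this issue: the hypothetical helper `read_model_metadata` (not an existing super-gradients function) reads both values from an already loaded checkpoint dict and falls back to caller-supplied values, failing with a readable error when neither is available.

```python
from typing import Any, Mapping, Optional, Tuple


def read_model_metadata(
    checkpoint: Mapping[str, Any],
    fallback_model_type: Optional[str] = None,
    fallback_classes_count: Optional[int] = None,
) -> Tuple[str, int]:
    """Return (model_type, classes_count) for a loaded checkpoint dict.

    Hypothetical helper: older checkpoints have no 'metadata' entry, so the
    caller can still pass the values by hand as a fallback.
    """
    metadata = checkpoint.get("metadata") or {}
    model_type = metadata.get("model_type", fallback_model_type)
    classes_count = metadata.get("classes_count", fallback_classes_count)
    if model_type is None or classes_count is None:
        raise KeyError(
            "checkpoint has no model_type/classes_count metadata; "
            "pass fallback_model_type and fallback_classes_count explicitly"
        )
    return model_type, int(classes_count)
```

With `checkpoint = torch.load(path, map_location='cpu')`, the returned pair plugs straight into `models.get(model_type, num_classes=classes_count, checkpoint_path=path)`, so new checkpoints load with no manual input while old ones keep working.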

This would remove all manual work from loading custom models. I get quite a lot of requests for this here.

mikel-brostrom, Jul 20 '23 06:07