Allow the Detectron2 model to set up metadata manually instead of taking it from `MetadataCatalog.get(cfg.DATASETS.TRAIN[0])`
First, thanks for creating this repo. I had been trying to mimic a video buffer just to get per-slice detection on a static image. It partly worked, but I could not reach the speed or final accuracy I needed.
As per my title, I'd like to suggest allowing manual setup of metadata in the `Detectron2DetectionModel.load_model` method. Specifically, https://github.com/obss/sahi/blob/b93132d175a729079f526452587fa7fb1df5f1c9/sahi/model.py#L539
Manually setting the `category_mapping` parameter is one way to update the final result, but it does not guarantee that every prediction is captured, and the missing metadata can still lead to over-prediction.
Specifically, this is the warning generated by the missing metadata:
WARNING - sahi.model - Attribute 'thing_classes' does not exist in the metadata of dataset 'dataset_name': metadata is empty.
As a workaround, I have to add (override) this line in the `load_model` method:
MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).set(thing_classes=['class_A','class_B'])
In some cases (especially in prediction), `cfg.DATASETS.TEST[0]` is preferred over `cfg.DATASETS.TRAIN[0]`. And to some extent, `cfg.DATASETS` need not be loaded at all.
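To make that preference order concrete, here is a minimal sketch of how the dataset name could be chosen. The helper `pick_metadata_dataset` and its `prefer_test` flag are hypothetical (they are not part of sahi or detectron2), and `train` / `test` simply stand in for `cfg.DATASETS.TRAIN` and `cfg.DATASETS.TEST`.

```python
from typing import Optional, Sequence


def pick_metadata_dataset(
    train: Sequence[str] = (),
    test: Sequence[str] = (),
    prefer_test: bool = True,
) -> Optional[str]:
    """Return the dataset name whose metadata should be read, or None.

    Hypothetical helper: `train` and `test` stand in for
    cfg.DATASETS.TRAIN and cfg.DATASETS.TEST.
    """
    first, second = (test, train) if prefer_test else (train, test)
    for names in (first, second):
        if names:  # an empty sequence means no dataset is configured here
            return names[0]
    return None  # nothing configured: the caller must supply class names itself


print(pick_metadata_dataset(train=("my_train",), test=("my_test",)))  # my_test
print(pick_metadata_dataset(train=("my_train",)))                     # my_train
print(pick_metadata_dataset())                                        # None
```

Returning `None` instead of raising covers the last case above, where no dataset is configured and the class names have to come from somewhere else.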
1. In general (aside from Detectron), it would be great to have an independent `__init__(*args, **kwargs)` method for each subclass (inheriting from the `DetectionModel` class) that preserves the parent's initialization (`super().__init__()`).
2. If (1) isn't a good option, then for Detectron, add `thing_classes: List[str]` as a parameter, with fallback code for when `None` is provided:
from detectron2.data import MetadataCatalog

class Detectron2DetectionModel(DetectionModel):
    def load_model(self):
        ...
        # detectron2 category mapping
        if self.category_mapping is None:
            try:  # try to parse category names from metadata
                if not self.thing_classes:
                    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
                else:
                    metadata = self.thing_classes
                    # setting `thing_classes` will not allow other trained labels to be outputted by the predictor
                    MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).set(thing_classes=self.thing_classes)
                ...
        # in addition, setting `MODEL.ROI_HEADS.NUM_CLASSES` keeps `self.category_mapping`
        # from looking up a non-existent label, which would raise an IndexError
        if not cfg.MODEL.ROI_HEADS.NUM_CLASSES and self.thing_classes:
            cfg.MODEL.ROI_HEADS.NUM_CLASSES = self.num_categories
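Both suggestions can be combined in one rough sketch that runs without detectron2. Everything here is an assumption for illustration: the `DetectionModel` below is a minimal stand-in and not sahi's real class, `resolve_category_mapping` is a hypothetical method, the metadata lookup is replaced by a plain `metadata_classes` list, and the `{"0": "class_A"}` shape of the mapping (string index to class name) is assumed.

```python
from typing import Dict, List, Optional


class DetectionModel:
    """Minimal stand-in for the parent class; NOT sahi's real implementation."""

    def __init__(
        self,
        model_path: Optional[str] = None,
        category_mapping: Optional[Dict[str, str]] = None,
    ):
        self.model_path = model_path
        self.category_mapping = category_mapping


class Detectron2DetectionModel(DetectionModel):
    def __init__(self, *args, thing_classes: Optional[List[str]] = None, **kwargs):
        # Hypothetical extra parameter; all other arguments reach the parent untouched.
        super().__init__(*args, **kwargs)
        self.thing_classes = thing_classes

    def resolve_category_mapping(
        self,
        metadata_classes: Optional[List[str]] = None,
        num_classes: int = 0,
    ) -> Dict[str, str]:
        """Priority: explicit mapping, then thing_classes, then metadata, then numeric names."""
        if self.category_mapping is not None:
            return self.category_mapping
        names = self.thing_classes or metadata_classes
        if not names:
            # Fall back to "0", "1", ... when no class names are available.
            names = [str(i) for i in range(num_classes)]
        self.category_mapping = {str(i): name for i, name in enumerate(names)}
        return self.category_mapping


model = Detectron2DetectionModel("model.pth", thing_classes=["class_A", "class_B"])
print(model.resolve_category_mapping())  # {'0': 'class_A', '1': 'class_B'}
```

Making `thing_classes` keyword-only keeps the parent's positional signature intact, so existing calls that do not pass it behave as before.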
Thank you and more powers to the team!