
Allow the Detectron2 model to set up metadata manually instead of setting it via `MetadataCatalog.get(cfg.DATASETS.TRAIN[0])`

Open snapcart-ruben opened this issue 2 years ago • 2 comments

First, thanks for creating this repo. I was also trying to mimic a video buffer just to achieve per-slice detection on a static image. It was somewhat successful, but I couldn't reach the speed and final accuracy I was after.

As the title says, I'd like to suggest allowing manual setup of metadata in the `Detectron2DetectionModel.load_model` method, specifically here: https://github.com/obss/sahi/blob/b93132d175a729079f526452587fa7fb1df5f1c9/sahi/model.py#L539

Although manually setting the `category_mapping` parameter is one way to update the final result, it doesn't guarantee that 100% of the predictions are captured; instead, the missing metadata can lead to over-prediction.

Specifically, the missing metadata produces this warning:

```
WARNING - sahi.model -   Attribute 'thing_classes' does not exist in the metadata of dataset 'dataset_name': metadata is empty.
```
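For context, sahi's `category_mapping` maps string category IDs to class names (for example `{"0": "class_A", "1": "class_B"}`). A minimal sketch, assuming that format, of building the mapping from a plain list of class names; `build_category_mapping` is a hypothetical helper, not part of sahi:

```python
from typing import Dict, List


def build_category_mapping(thing_classes: List[str]) -> Dict[str, str]:
    """Map each class index (as a string key) to its class name.

    Hypothetical helper: mirrors the {"0": name0, "1": name1, ...} shape
    used by sahi's `category_mapping` argument.
    """
    if not thing_classes:
        raise ValueError("thing_classes must be a non-empty list of class names")
    return {str(index): name for index, name in enumerate(thing_classes)}


mapping = build_category_mapping(["class_A", "class_B"])
print(mapping)  # {'0': 'class_A', '1': 'class_B'}
```

Passing such a dict as `category_mapping` sidesteps the metadata lookup, but on its own it does not address the over-prediction described above.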

As a workaround, I had to add (override) this line in the `load_model` method:

```python
from detectron2.data import MetadataCatalog

MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).set(thing_classes=['class_A', 'class_B'])
```

In some cases (especially during prediction), `cfg.DATASETS.TEST[0]` is preferable to the former.

In others, `cfg.DATASETS` need not be loaded at all.
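Choosing which dataset's metadata to read could be factored into a small helper. A minimal sketch; `select_metadata_dataset` is a hypothetical name, not a sahi or detectron2 function. It prefers the first test dataset, falls back to the first training dataset, and returns `None` when both are empty. detectron2's default config leaves `DATASETS.TRAIN` and `DATASETS.TEST` as empty tuples, so indexing `[0]` unconditionally can raise an `IndexError`.

```python
from typing import Optional, Sequence


def select_metadata_dataset(
    train_datasets: Sequence[str],
    test_datasets: Sequence[str],
    prefer_test: bool = True,
) -> Optional[str]:
    """Pick the dataset name whose metadata should supply class names.

    Hypothetical helper: returns None when neither sequence has an entry,
    meaning the metadata lookup should be skipped entirely.
    """
    ordered = (test_datasets, train_datasets) if prefer_test else (train_datasets, test_datasets)
    for names in ordered:
        if names:  # empty tuples (the detectron2 default) are skipped
            return names[0]
    return None


print(select_metadata_dataset(("my_train",), ("my_val",)))  # my_val
print(select_metadata_dataset(("my_train",), ()))  # my_train
print(select_metadata_dataset((), ()))  # None
```

A caller would pass `cfg.DATASETS.TRAIN` and `cfg.DATASETS.TEST` and only touch `MetadataCatalog` when the result is not `None`.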

  1. In general (not only for Detectron2), it would be great for each subclass of `DetectionModel` to have its own `__init__(*args, **kwargs)` method that still calls the parent constructor via `super().__init__()`.

  2. If (1) isn't a good option, add a `thing_classes: List[str]` parameter to the Detectron2 model, with fallback code for when `None` is provided:

```python
from detectron2.data import MetadataCatalog


class Detectron2DetectionModel(DetectionModel):
    def load_model(self):
        ...
        # detectron2 category mapping
        if self.category_mapping is None:
            try:  # try to parse category names from metadata
                if not self.thing_classes:
                    category_names = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).thing_classes
                else:
                    category_names = self.thing_classes
                    # setting `thing_classes` will not allow other trained labels to be outputted by the predictor
                    MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).set(thing_classes=self.thing_classes)
                ...
            except Exception:
                ...
        # in addition, setting `MODEL.ROI_HEADS.NUM_CLASSES` ensures that `self.category_mapping`
        # does not look up a non-existent label, which would cause an IndexError for `None` types
        # (NUM_CLASSES is an int, so it is checked directly rather than passed to len())
        if self.thing_classes and not cfg.MODEL.ROI_HEADS.NUM_CLASSES:
            cfg.MODEL.ROI_HEADS.NUM_CLASSES = self.num_categories
```
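Both suggestions can be illustrated together without detectron2 or sahi installed. A minimal sketch, assuming a parent class that may call `load_model()` from its constructor: `DetectionModelStub`, `Detectron2ModelSketch` and `resolve_category_names` are hypothetical names, `metadata_classes` stands in for the dataset metadata, `num_classes` stands in for `cfg.MODEL.ROI_HEADS.NUM_CLASSES`, and the index-based fallback is an assumption rather than confirmed sahi behaviour.

```python
from typing import Dict, List, Optional


def resolve_category_names(
    thing_classes: Optional[List[str]],
    metadata_classes: Optional[List[str]],
    num_classes: int,
) -> List[str]:
    """Return class names using the fallback order: manual, metadata, indices."""
    if thing_classes:  # manually supplied names take priority
        return list(thing_classes)
    if metadata_classes:  # otherwise use whatever the dataset metadata provides
        return list(metadata_classes)
    # no names anywhere: use stringified indices as placeholder names (assumption)
    return [str(index) for index in range(num_classes)]


class DetectionModelStub:
    """Hypothetical stand-in for sahi's DetectionModel; not its real signature."""

    def __init__(self, category_mapping: Optional[Dict[str, str]] = None, load_at_init: bool = True):
        self.category_mapping = category_mapping
        if load_at_init:  # assumption: the parent may load the model in __init__
            self.load_model()

    def load_model(self) -> None:
        raise NotImplementedError


class Detectron2ModelSketch(DetectionModelStub):
    """Subclass with its own __init__ that still delegates to the parent."""

    metadata_classes: Optional[List[str]] = None  # stands in for empty dataset metadata
    num_classes: int = 3  # stands in for cfg.MODEL.ROI_HEADS.NUM_CLASSES

    def __init__(self, *args, thing_classes: Optional[List[str]] = None, **kwargs):
        # Store the new parameter BEFORE calling the parent constructor, because
        # the parent may call load_model(), which reads self.thing_classes.
        self.thing_classes = thing_classes
        super().__init__(*args, **kwargs)

    def load_model(self) -> None:
        if self.thing_classes:
            # keep the class count consistent with the manually supplied names
            self.num_classes = len(self.thing_classes)
        if self.category_mapping is None:
            names = resolve_category_names(self.thing_classes, self.metadata_classes, self.num_classes)
            self.category_mapping = {str(index): name for index, name in enumerate(names)}


manual = Detectron2ModelSketch(thing_classes=["class_A", "class_B"])
print(manual.category_mapping)  # {'0': 'class_A', '1': 'class_B'}
print(Detectron2ModelSketch().category_mapping)  # {'0': '0', '1': '1', '2': '2'}
```

Setting `self.thing_classes` before `super().__init__()` is the key design choice: with the order reversed, a parent that loads the model during construction would raise an `AttributeError` when `load_model()` reads the attribute.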

Thank you and more powers to the team!

snapcart-ruben, Sep 07 '22 08:09