mega.pytorch

Return bbox class confidence vectors

Open · AlbertoSabater opened this issue 3 years ago • 1 comment

Hi! When running object detection inference, I would like to return the full class confidence vector for each bbox, not only the best score/class. Could you point me to the files I should modify to get them?

Thank you in advance, Alberto

AlbertoSabater · Oct 01 '20 15:10

I have finally managed to extract the logits along with the detections. To do so, I had to apply NMS to all the detections at once, without splitting them by class. Since your method regresses one bounding box per class for each detection, I had to choose the class with the highest confidence before selecting its associated bounding box. This is the final code I modified, in case anyone is interested:

    def filter_results_v2(self, boxlist, num_classes):
        """Returns bounding-box detection results by thresholding on scores and
        applying class-agnostic non-maximum suppression (NMS). The per-class
        scores of each kept detection are stored in the "logits" field.
        """
        boxes = boxlist.bbox.reshape(-1, num_classes * 4)
        scores = boxlist.get_field("scores").reshape(-1, num_classes)

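        # Best foreground class per detection (column 0 is background)
        best_scores, best_class_inds = torch.max(scores[:, 1:], 1)
        # Shift to the full class index so it matches the per-class box layout
        # (class c occupies columns c * 4 : (c + 1) * 4) and 0 stays background
        best_class_inds = best_class_inds + 1

        # Keep only detections whose best score passes the threshold
        above_thresh = best_scores > self.score_thresh
        best_scores = best_scores[above_thresh]
        best_logits = scores[above_thresh, 1:]  # column k corresponds to label k + 1
        best_class_inds = best_class_inds[above_thresh]
        boxes = boxes[above_thresh]

        # For each detection, select the box regressed for its best class.
        # Indexing also covers the empty case and yields a (0, 4) tensor.
        rows = torch.arange(boxes.shape[0], device=boxes.device)
        boxes = boxes.reshape(boxes.shape[0], num_classes, 4)[rows, best_class_inds]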

        boxlist_for_class = BoxList(boxes, boxlist.size, mode="xyxy")
        boxlist_for_class.add_field("scores", best_scores)
        boxlist_for_class.add_field("labels", best_class_inds)
        boxlist_for_class.add_field("logits", best_logits)
        boxlist_for_class = boxlist_nms(boxlist_for_class, self.nms)
        
        result = boxlist_for_class
        number_of_detections = len(result)

        # Limit to max_per_image detections **over all classes**
        if number_of_detections > self.detections_per_img > 0:
            cls_scores = result.get_field("scores")
            image_thresh, _ = torch.kthvalue(
                cls_scores.cpu(), number_of_detections - self.detections_per_img + 1
            )
            keep = cls_scores >= image_thresh.item()
            keep = torch.nonzero(keep).squeeze(1)
            result = result[keep]
        return result

The code above belongs to this file.
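For readers who want to see the idea in isolation, here is a minimal, framework-free sketch of the same steps: pick the best foreground class per detection, take the box regressed for that class, and run one greedy NMS pass over all detections regardless of class. It is only an illustration under stated assumptions, not the repository's implementation: the names `iou` and `class_agnostic_nms`, the dictionary keys, and the toy inputs are all hypothetical, and plain Python lists stand in for tensors and `BoxList`.

```python
from typing import Dict, List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection over union of two xyxy boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def class_agnostic_nms(
    scores: Sequence[Sequence[float]],  # [num_dets][num_classes], column 0 = background
    boxes: Sequence[Sequence[Box]],     # [num_dets][num_classes], one box per class
    score_thresh: float,
    nms_thresh: float,
) -> List[Dict]:
    """Hypothetical helper: threshold, select best-class box, NMS across classes."""
    candidates = []
    for det_scores, det_boxes in zip(scores, boxes):
        foreground = det_scores[1:]
        # +1 converts the foreground index back to the full class index
        best = max(range(len(foreground)), key=foreground.__getitem__) + 1
        if det_scores[best] > score_thresh:
            candidates.append({
                "label": best,
                "score": det_scores[best],
                "box": det_boxes[best],
                "class_scores": list(foreground),  # full confidence vector
            })

    # Greedy NMS over all candidates at once, highest score first
    candidates.sort(key=lambda d: d["score"], reverse=True)
    kept: List[Dict] = []
    for cand in candidates:
        if all(iou(cand["box"], k["box"]) <= nms_thresh for k in kept):
            kept.append(cand)
    return kept


# Toy input: 3 detections, 3 classes (background + 2 foreground classes)
scores = [
    [0.05, 0.90, 0.05],  # best class 1
    [0.10, 0.30, 0.60],  # best class 2, same region as detection 0
    [0.20, 0.10, 0.70],  # best class 2, different region
]
boxes = [
    [(0.0, 0.0, 0.0, 0.0), (10.0, 10.0, 50.0, 50.0), (12.0, 12.0, 48.0, 48.0)],
    [(0.0, 0.0, 0.0, 0.0), (11.0, 11.0, 49.0, 49.0), (10.0, 10.0, 50.0, 50.0)],
    [(0.0, 0.0, 0.0, 0.0), (100.0, 100.0, 140.0, 140.0), (100.0, 100.0, 150.0, 150.0)],
]
kept = class_agnostic_nms(scores, boxes, score_thresh=0.5, nms_thresh=0.5)
for det in kept:
    print(det["label"], det["score"], det["box"], det["class_scores"])
```

In this toy input, detection 1 is suppressed by detection 0 even though their best classes differ. That is the trade-off of running NMS without splitting by class: each surviving box carries a single confidence vector, but overlapping objects of different classes can suppress each other.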

AlbertoSabater · Oct 05 '20 08:10