insightface icon indicating copy to clipboard operation
insightface copied to clipboard

how to modify the tools/scrfd.py to enable support batch-size input data?

Open chunniunai220ml opened this issue 4 years ago • 3 comments

hi, how to modify the tools/scrfd.py to enable support batch-size input data?

and I have converted the scrfd.onnx model with a dynamic batch size, but I cannot get the correct post-process results. I changed scrfd.py (mainly the forward and detect functions) as follows:

def forward(self, bs_imgs, thresh, batch_size):
    """Run the detection network on a batch of preprocessed images.

    Args:
        bs_imgs: list of HxWx3 uint8 images, all already letterboxed to
            ``self.input_size`` (see ``prepare_img``).
        thresh: score threshold; unused here (filtering happens later in
            ``detect`` via ``filter_by_scores``) but kept for interface
            compatibility.
        batch_size: number of images in ``bs_imgs``.

    Returns:
        Tuple ``(scores_list, bboxes_list, kpss_list)``: one entry per FPN
        stride, each reshaped to ``(batch_size, -1, dims)``.

    Side effect:
        Stores the per-stride anchor centers in ``self.total_centers`` for
        use by ``detect``.
    """
    scores_list = []
    bboxes_list = []
    kpss_list = []

    # Build one NCHW blob for the whole batch, normalized to roughly [-1, 1].
    blob = cv2.dnn.blobFromImages(bs_imgs, 1.0 / 128, self.input_size,
                                  (127.5, 127.5, 127.5), swapRB=True)
    net_outs = self.session.run(self.output_names, {self.input_name: blob})

    input_height = blob.shape[2]
    input_width = blob.shape[3]
    fmc = self.fmc
    self.total_centers = []
    for idx, stride in enumerate(self._feat_stride_fpn):
        # Outputs are grouped by head: [scores..., bboxes..., keypoints...],
        # with fmc outputs per group, one per stride.
        scores_preds = net_outs[idx]
        bbox_preds = net_outs[idx + fmc] * stride
        kps_preds = net_outs[idx + fmc * 2] * stride

        # Reshape flat per-anchor predictions back into per-image rows.
        scores_list.append(scores_preds.reshape(batch_size, -1, 1))
        bboxes_list.append(bbox_preds.reshape(batch_size, -1, 4))
        kpss_list.append(kps_preds.reshape(batch_size, -1, 10))

        # Anchor centers depend only on feature-map geometry, so cache them
        # keyed by (height, width, stride).
        height = input_height // stride
        width = input_width // stride
        key = (height, width, stride)
        if key in self.center_cache:
            self.anchor_centers = self.center_cache[key]
        else:
            self.anchor_centers = np.stack(
                np.mgrid[:height, :width][::-1], axis=-1).astype(np.float32)
            self.anchor_centers = (self.anchor_centers * stride).reshape((-1, 2))
            if self._num_anchors > 1:
                self.anchor_centers = np.stack(
                    [self.anchor_centers] * self._num_anchors,
                    axis=1).reshape((-1, 2))
            # Bound the cache so unusual input sizes cannot grow it forever.
            if len(self.center_cache) < 100:
                self.center_cache[key] = self.anchor_centers

        self.total_centers.append(self.anchor_centers)
    return scores_list, bboxes_list, kpss_list


def prepare_img(self, img, input_size=None):
    """Letterbox *img* into the model input size.

    The image is resized to fit ``input_size`` (width, height) while
    keeping its aspect ratio, then placed top-left on a black canvas of
    exactly ``input_size``.

    Side effect: stores the applied scale factor in ``self.det_scale``
    so detections can be mapped back to original image coordinates.
    """
    assert input_size is not None or self.input_size is not None
    input_size = self.input_size if input_size is None else input_size
    target_w, target_h = input_size
    img_h, img_w = img.shape[0], img.shape[1]
    im_ratio = float(img_h) / img_w
    model_ratio = float(target_h) / target_w
    if im_ratio > model_ratio:
        # Image is relatively taller than the model input: height limits.
        new_h = target_h
        new_w = int(new_h / im_ratio)
    else:
        # Image is relatively wider (or equal): width limits.
        new_w = target_w
        new_h = int(new_w * im_ratio)
    self.det_scale = float(new_h) / img_h
    canvas = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    canvas[:new_h, :new_w, :] = cv2.resize(img, (new_w, new_h))
    return canvas

def detect(self, bs_imgs, thresh=0.5, max_num=0, metric='default'):
    """Detect faces on a batch of images.

    Args:
        bs_imgs: list of HxWx3 uint8 images (arbitrary, possibly mixed sizes).
        thresh: score threshold forwarded to the network stage.
        max_num: if > 0, keep at most this many detections per image.
        metric: 'max' ranks surviving boxes purely by area; anything else
            penalizes distance from the image center as well.

    Returns:
        Tuple ``(bs_det, bs_kpss)``: per-image arrays of detections
        ``(n, 5)`` (x1, y1, x2, y2, score) and keypoints ``(n, 5, 2)``,
        both in the ORIGINAL image's coordinate system.
    """
    # Letterbox every image and remember its individual rescale factor.
    # BUG FIX: the previous version read self.det_scale inside the
    # per-image loop below, but prepare_img overwrites it on every call,
    # so all images in a batch were rescaled by the LAST image's factor —
    # wrong whenever the batch mixes image sizes.
    prepared = []
    det_scales = []
    for img in bs_imgs:
        prepared.append(self.prepare_img(img))
        det_scales.append(self.det_scale)
    bs_imgs = prepared

    scores_list, bboxes_list, kpss_list = self.forward(bs_imgs, thresh, len(bs_imgs))

    bs_det = []
    bs_kpss = []
    for i in range(len(bs_imgs)):
        # Gather score-filtered candidates from every FPN level (was three
        # copy-pasted calls for P3/P4/P5; now generalized to any number of
        # strides the model exposes).
        pos_scores = []
        pos_bboxes = []
        pos_kpss = []
        for lvl in range(len(scores_list)):
            s, b, k = self.filter_by_scores(scores_list[lvl][i],
                                            bboxes_list[lvl][i],
                                            kpss_list[lvl][i],
                                            self.total_centers[lvl])
            pos_scores.append(s)
            pos_bboxes.append(b)
            pos_kpss.append(k)

        # Map back to this image's original coordinates with ITS scale.
        i_score = np.vstack(pos_scores)
        i_bbox = np.vstack(pos_bboxes) / det_scales[i]
        i_kps = np.vstack(pos_kpss) / det_scales[i]

        # Sort by score descending, then suppress overlaps.
        order = i_score.ravel().argsort()[::-1]
        pre_det = np.hstack((i_bbox, i_score)).astype(np.float32, copy=False)
        pre_det = pre_det[order, :]
        keep = self.nms(pre_det)
        det = pre_det[keep, :]
        kpss = i_kps[order, :, :][keep, :, :]

        if max_num > 0 and det.shape[0] > max_num:
            # Keep the max_num "best" boxes: ranked by area, optionally
            # penalized by squared distance from the image center.
            img = bs_imgs[i]
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            img_center = img.shape[0] // 2, img.shape[1] // 2
            offsets = np.vstack([
                (det[:, 0] + det[:, 2]) / 2 - img_center[1],
                (det[:, 1] + det[:, 3]) / 2 - img_center[0],
            ])
            offset_dist_squared = np.sum(np.power(offsets, 2.0), 0)
            if metric == 'max':
                values = area
            else:
                # Extra weight on centering.
                values = area - offset_dist_squared * 2.0
            bindex = np.argsort(values)[::-1][:max_num]
            det = det[bindex, :]
            if kpss is not None:
                kpss = kpss[bindex, :]

        bs_det.append(det)
        bs_kpss.append(kpss)
    return bs_det, bs_kpss

chunniunai220ml avatar Sep 07 '21 12:09 chunniunai220ml

Did you ever find this out? We'd like to do the same thing. Batch inference has been requested several times, but the authors don't respond.

dreamflasher avatar Nov 24 '21 17:11 dreamflasher

Hi, were you able to solve the issue? I want to run batch-wise inference on custom data.

tanya-suri avatar Mar 25 '22 06:03 tanya-suri

Hey guys, any updates here?

HoracceFeng avatar Jun 21 '22 07:06 HoracceFeng