miemiedetection
miemiedetection copied to clipboard
关于yolox_head_fast.py里面dynamic_k_matching函数的一个小bug
在参考作者代码对官方pipeline进行修改之后,终于统一了数据格式的问题,但是后来又遇到了一个问题,与合理的gt数目有关,源代码是这样的:
# 如果有预测框(花心大萝卜)匹配到了1个以上的gt时,做特殊处理。 if (anchor_matching_gt > 1).float().sum() > 0: # 首先,找到与花心大萝卜具有最小cost的gt。 # 找到 花心大萝卜 的下标(这是在anchor_matching_gt.shape[N, A]中的下标)。假设有R个花心大萝卜。 indexes = torch.where(anchor_matching_gt > 1) index = torch.stack((indexes[0], indexes[1]), 1) # [R, 2] 每个花心大萝卜2个坐标。第0个坐标表示第几张图片,第1个坐标表示第几个格子。 cost_t = cost.permute(0, 2, 1) # [N, G, A] -> [N, A, G] 转置好提取其cost cost2 = self.gather_nd(cost_t, index) # [R, G] 抽出 R个花心大萝卜 与 gt 两两之间的cost。 cost2 = cost2.permute(1, 0) # [G, R] gt 与 R个花心大萝卜 两两之间的cost。 cost_argmin = cost2.argmin(axis=0) # [R, ] 为 每个花心大萝卜 找到 与其cost最小的gt 的下标
我的代码会在进入这个判断后稳定跑飞,报错位置在求取下标index这一步
indexes = torch.where(anchor_matching_gt > 1)
报错信息为:
RuntimeError: numel: integer multiplication overflow
不太清楚这个where计算量在哪里,感觉非常的迷惑,希望能指点一下为什么会出现这个问题