UnitBox_TF
UnitBox_TF copied to clipboard
reproduced IOUloss by pytorch
I have reproduced the IOUloss by pytorch in mmdetection and replace the smmothl1 by it. However I got NaN.
the code is below:
def IOULoss(pred, target,weight, avg_factor=None): """ :param pred: the estimate position :param target: the ground truth position :param weight: weight of each loss :param avg_factor: number of examples :return: the IoU loss """ # the estimate position # print('size', pred.size()) # print('label size', label.size()) # print('label is ', target) xt, xb, xl, xr = torch.split(pred, split_size_or_sections=1, dim=1)
# the ground truth position'
gt, gb, gl, gr = torch.split(target, split_size_or_sections=1, dim=1)
# compute the bounding box size
X = (xt + xb) * (xl + xr)
G = (gt + gb) * (gl + gr)
# compute the IOU
Ih = torch.min(xt, gt) + torch.min(xb, gb)
Iw = torch.min(xl, gl) + torch.min(xr, gr)
_EPSILON = 10e-8
# I = torch.multiply(Ih, Iw, name="intersection")
I = Ih * Iw
U = X + G - I + _EPSILON
IoU = torch.div(I, U)
L = torch.where((1-torch.gt(gt, 0.01)),
torch.zeros_like(xt),
-torch.log(IoU + _EPSILON))
print('loss: ', torch.sum(L*weight)[None] / avg_factor)
return torch.sum(L*weight)[None] / avg_factor