UnitBox_TF

Reproduced IOULoss in PyTorch

Open · chituma110 opened this issue 5 years ago · 0 comments

I have reimplemented the IoU loss in PyTorch inside mmdetection and used it to replace the SmoothL1 loss. However, the loss comes out as NaN.
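One plausible cause of the NaN, offered as a hypothesis rather than a confirmed diagnosis: the UnitBox IoU formula assumes all four inputs are non-negative distances (top, bottom, left, right) from a pixel to the box edges. If, as with mmdetection's anchor heads, `pred` and `target` are the encoded box deltas that SmoothL1 normally receives, they can be negative. A single negative entry can make the IoU negative, `log` of a negative number is NaN, and multiplying by a zero weight does not cancel it because `0 * NaN` is still NaN. The helper `unitbox_iou` below is a hypothetical illustration of the same formula, not code from the issue.

```python
import torch

def unitbox_iou(pred, target, eps=1e-7):
    # pred/target: (N, 4) distances (top, bottom, left, right) from a pixel
    # to the box edges; the formula only makes sense if they are >= 0.
    xt, xb, xl, xr = pred.unbind(dim=1)
    gt, gb, gl, gr = target.unbind(dim=1)
    X = (xt + xb) * (xl + xr)                    # predicted box area
    G = (gt + gb) * (gl + gr)                    # ground-truth box area
    Ih = torch.min(xt, gt) + torch.min(xb, gb)   # intersection height
    Iw = torch.min(xl, gl) + torch.min(xr, gr)   # intersection width
    I = Ih * Iw                                  # intersection area
    return I / (X + G - I + eps)

target = torch.tensor([[2.0, 2.0, 2.0, 2.0],
                       [2.0, 2.0, 2.0, 2.0]])
# Second row holds a negative value, as an encoded regression delta might.
pred = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                     [-3.0, 1.0, 1.0, 1.0]])
weight = torch.tensor([1.0, 0.0])  # second sample "ignored" via a zero weight

iou = unitbox_iou(pred, target)
loss = -torch.log(iou) * weight    # log of a negative IoU is NaN; 0 * NaN stays NaN
print(iou)
print(loss.sum())
```

Here the first row gives IoU = 4 / 16 = 0.25, while the second row's negative top distance makes both the predicted area and the intersection negative, so its IoU is negative and the summed loss is NaN despite the zero weight.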

The code is below:

import torch


def IOULoss(pred, target, weight, avg_factor=None):
    """
    :param pred: the predicted distances (top, bottom, left, right), shape (N, 4)
    :param target: the ground-truth distances, same layout as pred
    :param weight: weight of each loss term
    :param avg_factor: number of examples used to average the loss
    :return: the IoU loss
    """
    # the predicted distances, each of shape (N, 1)
    xt, xb, xl, xr = torch.split(pred, split_size_or_sections=1, dim=1)

    # the ground-truth distances
    gt, gb, gl, gr = torch.split(target, split_size_or_sections=1, dim=1)

    # compute the bounding-box areas
    X = (xt + xb) * (xl + xr)
    G = (gt + gb) * (gl + gr)

    # compute the intersection height and width
    Ih = torch.min(xt, gt) + torch.min(xb, gb)
    Iw = torch.min(xl, gl) + torch.min(xr, gr)

    _EPSILON = 10e-8  # i.e. 1e-7

    # intersection, union and IoU
    I = Ih * Iw
    U = X + G - I + _EPSILON
    IoU = torch.div(I, U)

    # zero loss where the top distance is <= 0.01 (position outside any box);
    # a boolean comparison is used because `1 - bool_tensor` raises in PyTorch
    L = torch.where(gt <= 0.01,
                    torch.zeros_like(xt),
                    -torch.log(IoU + _EPSILON))

    if avg_factor is None:
        avg_factor = pred.size(0)  # assumption: average over the batch
    loss = torch.sum(L * weight)[None] / avg_factor
    print('loss: ', loss)
    return loss
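A minimal sketch of a numerically safer variant, under stated assumptions: `pred` and `target` are (N, 4) tensors of (top, bottom, left, right) distances as in UnitBox, not mmdetection's encoded deltas; `weight` is a hypothetical per-sample (N,) tensor (mmdetection passes (N, 4) box weights to SmoothL1); and the name `iou_loss` and the `avg_factor` default are illustrative choices. It differs from the code above in three ways: predictions are clamped at zero so areas and IoU cannot go negative, positions outside any box are dropped with boolean indexing before the sum so no `0 * NaN` can appear, and `avg_factor` falls back to the number of valid samples.

```python
import torch

def iou_loss(pred, target, weight=None, avg_factor=None, eps=1e-7):
    """-ln(IoU) loss on (top, bottom, left, right) distances (hypothetical sketch).

    pred, target: (N, 4) tensors.
    weight: optional (N,) per-sample weights (assumed shape).
    avg_factor: normaliser; defaults to the number of valid samples.
    """
    # Keep only positions inside a ground-truth box (same 0.01 threshold on the
    # top distance as the issue's code); the others never enter the sum.
    valid = target[:, 0] > 0.01
    pred, target = pred[valid], target[valid]
    if weight is None:
        weight = torch.ones_like(pred[:, 0])
    else:
        weight = weight[valid]

    # Distances must be non-negative for the area formulas to hold;
    # clamping at zero acts like a ReLU on the prediction.
    pred = pred.clamp(min=0)
    xt, xb, xl, xr = pred.unbind(dim=1)
    gt, gb, gl, gr = target.unbind(dim=1)

    X = (xt + xb) * (xl + xr)                     # predicted area
    G = (gt + gb) * (gl + gr)                     # ground-truth area
    Ih = torch.min(xt, gt) + torch.min(xb, gb)    # intersection height
    Iw = torch.min(xl, gl) + torch.min(xr, gr)    # intersection width
    I = Ih * Iw                                   # 0 <= I <= min(X, G)
    U = X + G - I                                 # >= 0 once inputs are non-negative
    iou = I / (U + eps)                           # lies in [0, 1)
    loss = -torch.log(iou + eps)                  # finite because iou + eps > 0

    if avg_factor is None:
        avg_factor = max(int(valid.sum()), 1)
    return (loss * weight).sum() / avg_factor
```

With `target = [[2, 2, 2, 2]]` and `pred = [[1, 1, 1, 1]]` the IoU is 4 / 16 and the loss is about ln 4. One design caveat: clamping gives zero gradient to negative predictions, so in practice it may work better to make the regression output non-negative by construction, for example by predicting log-distances and exponentiating them.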

chituma110, May 25 '19 15:05