BiSeNet icon indicating copy to clipboard operation
BiSeNet copied to clipboard

Right Dice Loss?

Open hitsz-zuoqi opened this issue 4 years ago • 0 comments

class DiceLoss(nn.Module): def __init__(self): super().__init__() self.epsilon = 1e-5

def forward(self, output, target):
    # print(output.shape)
    # print(target.shape)

    assert output.size() == target.size(), "'input' and 'target' must have the same shape"
    # 在classes上做softmax
    output = F.softmax(output, dim=1)
    # 打平tensor
    output = flatten(output) # [num_classes,B*H*W]
    target = flatten(target) # [num_classes,B*H*W]
    # intersect = (output * target).sum(-1).sum() + self.epsilon
    # denominator = ((output + target).sum(-1)).sum() + self.epsilon

    intersect = (output * target).sum(-1)
    denominator = (output + target).sum(-1)
    # dice --(0-0.5)
    dice = intersect / denominator
    dice = torch.mean(dice)
    # 1-dice (0.5,1)???
    return 1 - dice
    # return 1 - 2. * intersect / denominator

double the intersection over union?

hitsz-zuoqi avatar Dec 03 '20 05:12 hitsz-zuoqi