BiSeNet
BiSeNet copied to clipboard
Right Dice Loss?
class DiceLoss(nn.Module): def __init__(self): super().__init__() self.epsilon = 1e-5
def forward(self, output, target):
# print(output.shape)
# print(target.shape)
assert output.size() == target.size(), "'input' and 'target' must have the same shape"
# 在classes上做softmax
output = F.softmax(output, dim=1)
# 打平tensor
output = flatten(output) # [num_classes,B*H*W]
target = flatten(target) # [num_classes,B*H*W]
# intersect = (output * target).sum(-1).sum() + self.epsilon
# denominator = ((output + target).sum(-1)).sum() + self.epsilon
intersect = (output * target).sum(-1)
denominator = (output + target).sum(-1)
# dice --(0-0.5)
dice = intersect / denominator
dice = torch.mean(dice)
# 1-dice (0.5,1)???
return 1 - dice
# return 1 - 2. * intersect / denominator
double the intersection over union?