Modify the Dice loss
PR types
[Bug fixes]
PR changes
[Models]
Description
The Dice loss in paddleseg.models.losses.dice_loss and paddleseg.models.losses.maskformer_loss is modified, following JDTLoss and segmentation_models.pytorch.
The original Dice loss is incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, it is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{|x|_1 + |y|_1 - |x - y|_1}{2}$. This reformulation is provably equivalent to the original version when the ground truth is binary (i.e., one-hot hard labels). Moreover, since the new version is minimized if and only if the prediction is identical to the ground truth, even when the ground truth contains fractional values, it resolves the issue with soft labels [1, 2].
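To make the single-pixel case concrete, here is a minimal sanity check in plain Python (a standalone sketch, not part of the PR): with ground truth $y = 0.5$, the original Dice score keeps growing as the prediction $x$ approaches 1, whereas the reformulated score peaks exactly at $x = y$.

```python
# Single-pixel comparison of the two Dice formulations (standalone sketch).
y = 0.5
for x in (0.25, 0.5, 0.75, 1.0):
    dice_old = 2 * x * y / (x + y)             # original: 2xy / (x + y)
    dice_new = (x + y - abs(x - y)) / (x + y)  # reformulated intersection
    print(f"x={x:.2f}  old={dice_old:.4f}  new={dice_new:.4f}")
# x=0.25  old=0.3333  new=0.6667
# x=0.50  old=0.5000  new=1.0000   <- new score is maximal at x = y
# x=0.75  old=0.6000  new=0.8000
# x=1.00  old=0.6667  new=0.6667   <- old score is maximal here instead
```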
In summary, there are three scenarios:
- [Scenario 1] $x$ is nonnegative and $y$ is binary: the new version is identical to the original version.
- [Scenario 2] Both $x$ and $y$ are nonnegative: the new version differs from the original. It is minimized if and only if $x = y$, while the original version may not be, making the original incorrect.
- [Scenario 3] Either $x$ or $y$ is negative: the new version differs from the original. It is minimized if and only if $x = y$, while the original version may not be, making the original incorrect (see the sketch after the example below).
Due to these differences, particularly in Scenarios 2 and 3, some tests may fail with the new version. The failures are expected since the original version is incorrectly defined for non-binary ground truth.
Example
```python
import paddle
import paddle.nn.functional as F

paddle.seed(0)

b, c, h, w = 4, 3, 32, 32
axis = (0, 2, 3)  # reduce over batch and spatial dims; keep one score per class

pred = F.softmax(paddle.rand((b, c, h, w)), axis=1)
soft_label = F.softmax(paddle.rand((b, c, h, w)), axis=1)
hard_label = paddle.randint(low=0, high=c, shape=(b, h, w))
one_hot_label = paddle.transpose(F.one_hot(hard_label, c), perm=(0, 3, 1, 2))

def dice_old(x, y, axis):
    # Original: 2 * sum(x * y) / (sum(x) + sum(y))
    cardinality = paddle.sum(x, axis=axis) + paddle.sum(y, axis=axis)
    intersection = paddle.sum(x * y, axis=axis)
    return 2 * intersection / cardinality

def dice_new(x, y, axis):
    # Reformulated intersection: (sum(x) + sum(y) - sum(|x - y|)) / 2
    cardinality = paddle.sum(x, axis=axis) + paddle.sum(y, axis=axis)
    difference = paddle.sum(paddle.abs(x - y), axis=axis)
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

# Scenario 1 (one-hot labels): both versions agree.
print(dice_old(pred, one_hot_label, axis), dice_new(pred, one_hot_label, axis))
# Scenario 2 (soft labels): the versions diverge.
print(dice_old(pred, soft_label, axis), dice_new(pred, soft_label, axis))
# Perfect prediction: only the new version reaches 1.
print(dice_old(pred, pred, axis), dice_new(pred, pred, axis))

# tensor([0.3356, 0.3308, 0.3319]) tensor([0.3356, 0.3308, 0.3319])
# tensor([0.3326, 0.3323, 0.3340]) tensor([0.8668, 0.8670, 0.8675])
# tensor([0.3505, 0.3512, 0.3513]) tensor([1., 1., 1.])
```
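Scenario 3 is not covered by the example above, since softmax outputs are nonnegative. Below is a small follow-up sketch (not part of the PR; `dice_l1` is a hypothetical helper) that uses the L1-norm form of the cardinality from the description, so the score still peaks only at $x = y$ even for inputs with negative entries such as raw logits.

```python
import paddle

def dice_l1(x, y, axis):
    # Cardinality via L1 norms so negative entries are handled;
    # score = (|x|_1 + |y|_1 - |x - y|_1) / (|x|_1 + |y|_1), maximal iff x == y.
    cardinality = paddle.sum(paddle.abs(x), axis=axis) + paddle.sum(paddle.abs(y), axis=axis)
    difference = paddle.sum(paddle.abs(x - y), axis=axis)
    return (cardinality - difference) / cardinality

paddle.seed(0)
logits = paddle.randn((4, 3, 32, 32))  # contains negative values
target = paddle.randn((4, 3, 32, 32))

print(dice_l1(logits, target, axis=(0, 2, 3)))  # strictly below 1 when x != y
print(dice_l1(logits, logits, axis=(0, 2, 3)))  # exactly 1 when x == y
```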
References
[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. MICCAI 2023.
[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. NeurIPS 2023.