MaskAttentionLoss in DiceCEEdgeLoss doesn't handle images without any edges
Describe the bug
Training models that use `DiceCEEdgeLoss` results in a NaN loss on images that contain only one semantic class. In that case, `edge_target` becomes a tensor filled with zeros because there are no edges in the image:
https://github.com/Deci-AI/super-gradients/blob/aa2745450e4c2617847874eb7f54bcab4ebb8f10/src/super_gradients/training/losses/dice_ce_edge_loss.py#L101-L103
Then, when computing the `MaskAttentionLoss`, `mask_loss` starts as a tensor filled with zeros, gets reassigned to an empty tensor by the `mask == 1` indexing, and finally computing the mean of that empty tensor results in NaN:
https://github.com/Deci-AI/super-gradients/blob/aa2745450e4c2617847874eb7f54bcab4ebb8f10/src/super_gradients/training/losses/mask_loss.py#L45-L47
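The failure mode can be reproduced standalone, without the library. This is a minimal sketch (not the repo's actual code): boolean indexing with an all-zero mask selects nothing, and the mean of an empty tensor is NaN.

```python
import torch

mask_loss = torch.zeros(2, 1, 4, 4)  # per-pixel losses (all zero here)
mask = torch.zeros(2, 1, 4, 4)       # edge mask for an image with no edges

# Boolean indexing with an all-False condition yields an empty tensor
selected = mask_loss[mask == 1]

print(selected.numel())  # 0
print(selected.mean())   # tensor(nan)
```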
To Reproduce
I've written a new test in `tests/unit_tests/mask_loss_test.py` that reproduces the problem:
```python
def test_with_cross_entropy_loss_maskless(self):
    """
    Test case with a mask filled with zeros, corresponding to a scenario without
    attention. It's expected that the mask doesn't contribute to the loss.
    This scenario may happen when using edge masks on an image without
    edges - there's only one semantic region in the whole image.
    Shapes: predict [BxCxHxW], target [BxHxW], mask [Bx1xHxW]
    """
    predict = torch.randn(self.batch, self.num_classes, self.img_size, self.img_size)
    target = self._get_default_target_tensor()
    # Create a mask filled with zeros to disable the attention component
    mask = self._get_default_mask_tensor() * 0.0
    loss_weights = [1.0, 0.5]
    ce_crit = nn.CrossEntropyLoss(reduction="none")
    mask_ce_crit = MaskAttentionLoss(criterion=ce_crit, loss_weights=loss_weights)
    # Expected result - no contribution from the mask
    ce_loss = ce_crit(predict, target)
    expected_loss = ce_loss.mean() * loss_weights[0]
    # Mask CE loss result
    loss = mask_ce_crit(predict, target, mask)
    self._assertion_torch_values(expected_loss, loss)
```
Running this test results in:

```
AssertionError: False is not true : Unequal torch tensors: excepted: 1.7192925214767456, found: nan
```
Expected behavior
A mask filled with zeros should "disable" attention; thus, the mask should not contribute to the loss.
Environment:
- super-gradients version: 3.0.7
Additional context
Can be fixed by checking whether `mask_loss` is NaN and setting it to 0 instead, like this:

```python
mask_loss = mask_loss if not mask_loss.isnan() else mask_loss.new_tensor(0.0)
```
@davidtvs thanks for reporting this bug, great catch!
Thanks for setting up the unit test for this use case. One thing I would comment on: instead of checking whether the reduced `mask_loss` is `nan`, which can also result from other unexpected causes, it's better IMO to check before the reduction whether the tensor is empty, which directly indicates that the mask is empty.
```python
mask_loss = mask_loss[mask == 1]  # consider only mask samples for mask loss computation
if mask_loss.numel() == 0:
    mask_loss = torch.tensor(0.0)
We highly encourage new contributors, feel free to open a PR if you find the time to do so. If not we'll push the fix ASAP.
This issue is fixed in the master branch and will be available in the next release: https://github.com/Deci-AI/super-gradients/pull/982
Thanks again @davidtvs for reporting the issue.