dice_loss_for_NLP
A question about flat_input and flat_target.
Suppose I have the following probs and labels (Binary classification):
probs = torch.FloatTensor([[0.3],
                           [0.8],
                           [0.2],
                           [0.7]])
targets = torch.LongTensor([[0],
                            [1],
                            [0],
                            [1]])
Execute the following code:
loss = DiceLoss(alpha=1, smooth=1, with_logits=False, ohem_ratio=0.0, reduction='mean')
output = loss(probs, targets)
print(output)
Since the last dimension of probs is 1, the code enters _binary_class():
def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
At this point, the value of flat_input (a 1-D tensor with 4 elements) is:
tensor([0.3000, 0.8000, 0.2000, 0.7000])
The value of flat_target (also 1-D with 4 elements) is:
tensor([0., 1., 0., 1.])
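To make the shapes explicit, here is a minimal sketch that reproduces only the two flattening lines quoted above on the binary example. It does not call DiceLoss itself; it only applies the same view(-1) calls to the same data.

```python
import torch

# Same binary example as above: one probability per row.
probs = torch.tensor([[0.3], [0.8], [0.2], [0.7]])
targets = torch.tensor([[0], [1], [0], [1]])

# The two lines quoted from _binary_class().
flat_input = probs.view(-1)
flat_target = targets.view(-1).float()

print(probs.shape)        # torch.Size([4, 1])
print(flat_input.shape)   # torch.Size([4])
print(flat_target.shape)  # torch.Size([4])
```

So in the binary path both tensors are 1-D when they reach the loss computation.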
So far, there is no problem. But when I switched to multi-class classification for testing, I found a problem with the shapes of flat_input and flat_target.
Suppose I have the following probs and labels (multi-class classification):
probs = torch.FloatTensor([[0.1, 0.8, 0.7],
                           [0.5, 0.1, 0.6],
                           [0.7, 0.5, 0.8],
                           [0.4, 0.6, 0.9]])
targets = torch.LongTensor([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1],
                            [0, 1, 0]])
Execute the following code:
loss = DiceLoss(alpha=1, smooth=1, with_logits=False, ohem_ratio=0.0, index_label_position=False, reduction='mean')
output = loss(probs, targets)
print(output)
Since the last dimension of probs is now 3, the code enters _multiple_class():
def _multiple_class(self, input, target, logits_size, mask=None):
    flat_input = input
    flat_target = F.one_hot(target, num_classes=logits_size).float() if self.index_label_position else target.float()
At this point, the first-class column of flat_input (a 1-D tensor with 4 elements) is:
tensor([0.1000, 0.5000, 0.7000, 0.4000])
The first-class column of flat_target is:
tensor([1., 0., 0., 0.])
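The values above are the first column of probs and targets. As a sketch of how such a per-class column can be obtained, the snippet below slices column 0 by hand. The slicing expression and the name label_idx are my assumptions for illustration; only flat_input_idx and flat_target_idx appear in the code quoted in this issue.

```python
import torch

probs = torch.tensor([[0.1, 0.8, 0.7],
                      [0.5, 0.1, 0.6],
                      [0.7, 0.5, 0.8],
                      [0.4, 0.6, 0.9]])
targets = torch.tensor([[1, 0, 0],
                        [0, 1, 0],
                        [0, 0, 1],
                        [0, 1, 0]])

# With index_label_position=False the targets are used as given.
flat_input = probs
flat_target = targets.float()

label_idx = 0  # hypothetical name: the class whose column is inspected
flat_input_idx = flat_input[:, label_idx]    # assumed slicing, 1-D result
flat_target_idx = flat_target[:, label_idx]

print(flat_input_idx)         # the first column of probs
print(flat_target_idx)        # the first column of targets
print(flat_input_idx.shape)   # torch.Size([4])
```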
But after the following line runs:
loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1))
The value of flat_input (now 2-D, with shape [4, 1]) is:
tensor([[0.1000],
        [0.5000],
        [0.7000],
        [0.4000]])
The value of flat_target (also with shape [4, 1]) is:
tensor([[1.],
        [0.],
        [0.],
        [0.]])
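I don't understand why, when the dice loss is computed, flat_input and flat_target have different shapes in the two cases: 1-D (shape [4]) in the binary case, but 2-D (shape [4, 1]) in the multi-class case.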