dice_loss_for_NLP

A question about flat_input and flat_target.

Open · NEMOlv opened this issue 2 years ago · 0 comments

Suppose I have the following probs and targets (binary classification):

import torch

probs = torch.FloatTensor([[0.3],
                           [0.8],
                           [0.2],
                           [0.7]])

targets = torch.LongTensor([[0],
                            [1],
                            [0],
                            [1]])

Execute the following code:

from loss.dice_loss import DiceLoss  # the DiceLoss class shipped in this repo

loss = DiceLoss(alpha=1, smooth=1, with_logits=False, ohem_ratio=0.0, reduction='mean')
output = loss(probs, targets)
print(output)

As expected, the code enters _binary_class():

def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()

At this point, flat_input is:

tensor([0.3000, 0.8000, 0.2000, 0.7000])

And flat_target is:

tensor([0., 1., 0., 1.])
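Just to confirm the shapes, the flattening can be checked in isolation (this is plain torch.Tensor.view behaviour, not code from the repo):

flat_input = probs.view(-1)             # torch.Size([4])
flat_target = targets.view(-1).float()  # torch.Size([4])
print(flat_input.shape, flat_target.shape)

So in the binary path both tensors are 1-D of length 4.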

So far, so good. But when I switched to multi-class classification for testing, I found a problem with the shapes of flat_input and flat_target.

Suppose I have the following probs and targets (multi-class classification):

probs = torch.FloatTensor([[0.1, 0.8, 0.7],
                           [0.5, 0.1, 0.6],
                           [0.7, 0.5, 0.8],
                           [0.4, 0.6, 0.9]])

targets = torch.LongTensor([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1],
                            [0, 1, 0]])

Execute the following code:

loss = DiceLoss(alpha=1, smooth=1, with_logits=False, ohem_ratio=0.0, index_label_position=False, reduction='mean')
output = loss(probs, targets)
print(output)

As expected, the code enters _multiple_class() (and since index_label_position=False, the already one-hot targets are used as-is):

def _multiple_class(self, input, target, logits_size, mask=None):
    flat_input = input
    flat_target = F.one_hot(target, num_classes=logits_size).float() if self.index_label_position else target.float()
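The method then computes the loss one class at a time, slicing a column out of flat_input and flat_target on each pass. The names flat_input_idx and flat_target_idx come from the call quoted further down; the loop itself is my reconstruction (the real method also handles masks and OHEM), reusing probs and targets from above:

# Hypothetical re-enactment of the per-class slicing in _multiple_class.
flat_input = probs             # with_logits=False: probabilities used as-is
flat_target = targets.float()  # index_label_position=False: already one-hot
for label_idx in range(3):     # one pass per class (logits_size == 3)
    flat_input_idx = flat_input[:, label_idx]    # class column, shape [4]
    flat_target_idx = flat_target[:, label_idx]  # class column, shape [4]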

On the first iteration (label_idx = 0), flat_input_idx is:

tensor([0.1000, 0.5000, 0.7000, 0.4000])

And flat_target_idx is:

tensor([1., 0., 0., 0.])

Both columns are then reshaped before the dice loss is computed:

loss_idx = self._compute_dice_loss(flat_input_idx.view(-1, 1), flat_target_idx.view(-1, 1))

After .view(-1, 1), flat_input_idx is:

tensor([[0.1000],
        [0.5000],
        [0.7000],
        [0.4000]])

And flat_target_idx is:

tensor([[1.],
        [0.],
        [0.],
        [0.]])
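The reshape is plain torch.Tensor.view; a quick check:

col = torch.FloatTensor([0.1000, 0.5000, 0.7000, 0.4000])
print(col.view(-1, 1).shape)  # torch.Size([4, 1])

So within the multi-class path flat_input_idx and flat_target_idx share the shape [4, 1]; it is the binary path above that passes 1-D tensors instead.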

I don't understand why flat_input and flat_target are passed to _compute_dice_loss with different shapes in the two cases: 1-D tensors of shape [4] in the binary path, but [4, 1] column vectors in the multi-class path.

NEMOlv · Jan 12 '23 12:01