Ignore classes such as Padding?

Open PonteIneptique opened this issue 4 years ago • 5 comments

Hi there! First, thanks for the code :) I am going to try Dice Loss for an NLP project I am contributing to, and I was wondering whether you think something like cross_entropy's ignore_index= is useless in the context of Dice? Thanks in advance for taking the time to reply :+1:

PonteIneptique avatar Dec 07 '20 07:12 PonteIneptique

@PonteIneptique Hi!

Totally makes sense. I assume you don't need to take padding tokens into account when calculating the loss value.

fursovia avatar Dec 07 '20 15:12 fursovia

How would you go about that? Would you do it before invoking DiceLoss, or modify DiceLoss to accept an index to ignore?

PonteIneptique avatar Dec 07 '20 15:12 PonteIneptique

Maybe something like this?

def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    ignore_condition = torch.ne(targets, self.ignore_index)
    logits = logits[ignore_condition]
    targets = targets[ignore_condition]
    ...
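
For reference, here is a minimal sketch of how that masking idea could be packaged as a wrapper around any loss that takes (logits, targets). MaskedLoss is a hypothetical name, not part of this repository, and nn.CrossEntropyLoss is used only as a stand-in criterion so the snippet runs on its own; a Dice loss instance with the same call signature could presumably be dropped in instead.

```python
import torch
from torch import nn


class MaskedLoss(nn.Module):
    """Hypothetical wrapper: drops positions whose target equals ignore_index."""

    def __init__(self, criterion: nn.Module, ignore_index: int = -100) -> None:
        super().__init__()
        self.criterion = criterion
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Flatten to (N, num_classes) and (N,) so one boolean mask covers both.
        logits = logits.reshape(-1, logits.size(-1))
        targets = targets.reshape(-1)
        # True where the target should contribute to the loss.
        keep = torch.ne(targets, self.ignore_index)
        return self.criterion(logits[keep], targets[keep])


# Stand-in criterion (assumption): swap in your Dice loss instance here.
criterion = MaskedLoss(nn.CrossEntropyLoss(), ignore_index=0)

logits = torch.randn(4, 6, 3)               # (batch, seq_len, num_classes)
targets = torch.randint(1, 3, size=(4, 6))  # real labels are 1 or 2
targets[:, -2:] = 0                         # last two positions are padding

loss = criterion(logits, targets)
```

With a mean-reduced criterion, a batch in which every position is ignored leaves nothing to average over, so that case may need a guard.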

fursovia avatar Dec 07 '20 16:12 fursovia

So you'd do it outside of your module ? This is basically my question :) Thanks for the quick answers though ! :)

PonteIneptique avatar Dec 07 '20 16:12 PonteIneptique

Yes, outside the module. Actually, the above example should have used tokens instead of targets to obtain the mask (ignore_condition).

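import torch

PADDING_IDX = 0
VOCAB_SIZE = 1000
BATCH_SIZE = 128
SEQ_LENGTH = 40
NUM_CLASSES = 5  # assumed number of classes, for illustration only

tokens = torch.randint(0, VOCAB_SIZE, size=(BATCH_SIZE, SEQ_LENGTH))
# Random placeholders standing in for the real labels and for my_model(tokens)
targets = torch.randint(0, NUM_CLASSES, size=(BATCH_SIZE, SEQ_LENGTH))
logits = torch.randn(BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES)

# True for real tokens, False for padding
ignore_condition = torch.ne(tokens, PADDING_IDX)

# Boolean indexing keeps the class dimension: logits becomes (num_kept, NUM_CLASSES).
# torch.masked_select would flatten logits to 1-D and needs a broadcastable mask.
targets = targets[ignore_condition]
logits = logits[ignore_condition]

# criterion is your loss instance; logits come first, matching forward(logits, targets)
loss = criterion(logits, targets)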

There is also another way if you would like to do it inside the module: you can add an ignore_loss_on_o_tags argument. If True, the loss is computed only for actual spans in tags and not on O tokens, i.e. all tokens with no tag are skipped (including the padding token).
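
A rough sketch of what such a flag could look like, written as a standalone helper rather than as a change to the module itself. The function name select_span_positions and the o_tag_idx parameter are hypothetical, and the sketch assumes padding positions are labelled with the same index as the O tag.

```python
import torch


def select_span_positions(
    logits: torch.Tensor,
    targets: torch.Tensor,
    o_tag_idx: int,
    ignore_loss_on_o_tags: bool = True,
) -> tuple:
    """Hypothetical helper: keep only positions that carry a real tag."""
    # Flatten to (N, num_classes) and (N,).
    logits = logits.reshape(-1, logits.size(-1))
    targets = targets.reshape(-1)
    if not ignore_loss_on_o_tags:
        return logits, targets
    # Assumption: O tokens and padding share the index o_tag_idx.
    keep = torch.ne(targets, o_tag_idx)
    return logits[keep], targets[keep]


O_TAG_IDX = 0
targets = torch.tensor([[0, 1, 2, 0], [2, 0, 0, 0]])  # (batch, seq_len)
logits = torch.randn(2, 4, 3)                         # (batch, seq_len, num_tags)

span_logits, span_targets = select_span_positions(logits, targets, O_TAG_IDX)
# span_targets now holds only the tagged positions: 1, 2, 2
```

The selected tensors can then be passed to the loss as criterion(span_logits, span_targets). One trade-off of this design is that the model gets no training signal on O tokens at all.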

fursovia avatar Dec 07 '20 16:12 fursovia