
Support -100 as ignore_index of Accuracy

Open richarddwang opened this issue 3 years ago • 10 comments

When running the instance Accuracy(ignore_index=-100), it raises an error saying -100 is not among the valid classes.

But ignore_index=-100 is the default for torch.nn.CrossEntropyLoss, and we often train masked language models with a label tensor of shape (batch_size, sequence_length) where positions not selected by the masking process are filled with -100 to prevent counting loss from non-masked positions.

We could still check that ignore_index is not greater than the number of classes, but lifting the restriction that ignore_index cannot be negative would be better.
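For reference, a minimal sketch of the failure (num_classes=10 and the tensors are arbitrary; the error wording follows the validation check discussed below):

import torch
from torchmetrics import Accuracy

acc = Accuracy(num_classes=10, ignore_index=-100)
preds = torch.randint(0, 10, (8,))
target = torch.randint(0, 10, (8,))
target[0] = -100  # masked position, as produced by MLM collators
acc(preds, target)
# ValueError: The `ignore_index` -100 is not valid for inputs with 10 classes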

richarddwang avatar Jun 20 '21 03:06 richarddwang

Hi! Thanks for your contribution, great first issue!

github-actions[bot] avatar Jun 20 '21 03:06 github-actions[bot]

Hi @richarddwang, sounds good to me. Would you be up for sending a PR?

The relevant check that should be changed: https://github.com/PyTorchLightning/metrics/blob/7af6d13b3c2186aacf5594793317c15a199608bb/torchmetrics/functional/classification/stat_scores.py#L95-L96 (and similar checks elsewhere)
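For illustration, a hedged sketch of the relaxation being proposed (not an actual diff; the real validation lives at the link above):

# Keep the upper bound but drop the requirement that ignore_index >= 0,
# so negative sentinels like -100 pass validation.
if num_classes and ignore_index is not None and ignore_index >= num_classes:
    raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")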

SkafteNicki avatar Jun 21 '21 18:06 SkafteNicki

Hi @SkafteNicki, I am not clear on how ignore_index is handled: https://github.com/PyTorchLightning/metrics/blob/7af6d13b3c2186aacf5594793317c15a199608bb/torchmetrics/functional/classification/stat_scores.py#L111-L113 https://github.com/PyTorchLightning/metrics/blob/7af6d13b3c2186aacf5594793317c15a199608bb/torchmetrics/functional/classification/stat_scores.py#L24-L27 Would it still work when ignore_index is -100?

richarddwang avatar Jun 22 '21 01:06 richarddwang

@richarddwang yeah, maybe it is a bit more complex than first thought. Could you please provide an example of the input you would normally feed into torch.nn.CrossEntropyLoss with ignore_index=-100, just so I can better understand?

SkafteNicki avatar Jun 23 '21 12:06 SkafteNicki

Sure, someone who wants to train a masked language model will often see something like this.

import torch
from torch import nn

batch_size, sequence_length, vocab_size = 4, 5, 128
mlm_logits = torch.randn(batch_size, sequence_length, vocab_size)
labels = torch.tensor([
    [1234, -100, -100, -100, -100],
    [-100, -100, -100, 7567, -100],
    [-100, -100, 8900, -100, -100],
    [-100, -100, -100, -100, 3456],
])  # (4, 5), where -100 marks positions not selected for the MLM to learn from
loss = nn.CrossEntropyLoss()(mlm_logits.view(-1, vocab_size), labels.view(-1))  # ignore_index defaults to -100
mlm_predictions = mlm_logits.argmax(dim=-1)  # (4, 5)
# inside a LightningModule, with mlm_acc = Accuracy(ignore_index=-100):
self.log('mlm_acc', mlm_acc(mlm_predictions.view(-1), labels.view(-1)))

richarddwang avatar Jun 23 '21 12:06 richarddwang

Can we also handle the case for segmentation with datasets like PASCAL VOC, which has a "void" label of 255 but the number of classes is 21? I tried passing in the number of classes but then I get the error:

The `ignore_index` 255 is not valid for inputs with 21 classes

It looks like there's a baked-in assumption in StatScores and torchmetrics.utilities.checks that the targets only contain valid class indices. But segmentation datasets often have a "void" label to mark things like boundaries where the label is ambiguous. So the model outputs a prediction tensor of shape (N, C, H, W) for C classes, and the target tensor has shape (N, H, W) with values ranging from 0 to C-1, plus the void label. It doesn't appear that TorchMetrics supports this currently, but I think it should, since PyTorch classes like torch.nn.NLLLoss have no issue with this setup.
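To illustrate (a minimal sketch; the 21 classes and 255 void label match PASCAL VOC, the shapes are made up):

import torch
import torch.nn.functional as F

num_classes = 21
logits = torch.randn(2, num_classes, 4, 4)                # (N, C, H, W) model output
target = torch.randint(0, num_classes, (2, 4, 4))         # (N, H, W) class indices
target[:, 0, 0] = 255                                     # "void" pixels on ambiguous boundaries
loss = F.cross_entropy(logits, target, ignore_index=255)  # no complaint; void pixels are skipped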

collinmccarthy avatar Jun 25 '21 17:06 collinmccarthy

Hi all. I wrote a naive solution: make the ignored positions compare as incorrect, then divide by the number of non-ignored positions. There may be a better way; this is just for reference.

import torch

@torch.no_grad()
def accuracy(
    logits_or_ids,  # <float>(B,L,C) logits or <long>(B,L) predicted ids
    labels,  # <long>(B,L)
    ignore_id: int,
):
    if len(logits_or_ids.shape) == len(labels.shape):
        preds = logits_or_ids  # already predicted ids: <long>(B,L)
    else:
        preds = logits_or_ids.argmax(dim=-1)  # <long>(B,L)
    ignore = labels == ignore_id
    # make ignored positions compare as incorrect, so they stay out of the correct count
    assert ignore_id != -9999  # -9999 is the sentinel used below
    preds = preds.masked_fill(ignore, -9999)
    num_correct = (preds == labels).sum()
    num_nonignore = (~ignore).sum()
    return num_correct / num_nonignore
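For example, with the mlm_logits and labels tensors from the earlier comment (a usage sketch):

mlm_acc = accuracy(mlm_logits, labels, ignore_id=-100)  # fraction correct over non-ignored positions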

richarddwang avatar Jul 21 '21 02:07 richarddwang

@SkafteNicki how are we doing here? :rabbit:

Borda avatar Aug 08 '21 09:08 Borda

This feature is a must-have also for F1, AUROC, and any other classification metric.

TezRomacH avatar Aug 12 '21 16:08 TezRomacH

I wrote a simple wrapper to flat out ignore the targets with ignore_index.

from typing import Optional

from torch import Tensor
from torchmetrics import functional as FM

def ignore_index(func):
    def ignore_and_call(preds: Tensor, target: Tensor, ignore_target: Optional[int] = None, *args, **kwargs):
        keep = target != ignore_target  # boolean mask of positions to evaluate
        res = func(*args, **kwargs, preds=preds[keep], target=target[keep])
        return res
    return ignore_and_call  # return only the wrapper, not a tuple

accuracy = ignore_index(FM.accuracy)

Unfortunately, done like this, you lose the docstring, signature, etc. To prevent this, you could e.g. copy the accuracy method's declaration and call the wrapped method from within.
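One way to keep them is functools.wraps, which copies the wrapped function's name and docstring onto the wrapper (a sketch along the same lines as the code above):

from functools import wraps

def ignore_index(func):
    @wraps(func)  # preserves func.__name__ and func.__doc__
    def ignore_and_call(preds, target, ignore_target=None, *args, **kwargs):
        keep = target != ignore_target
        return func(*args, **kwargs, preds=preds[keep], target=target[keep])
    return ignore_and_call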

alexsr avatar Sep 03 '21 13:09 alexsr

This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all the changes.

Small recap: this issue asks for ignore_index to be allowed outside the [0, num_classes] range for the accuracy metrics. After the refactor, all new classification metrics support this. Here are accuracy and jaccard as examples:

from torchmetrics.functional import binary_accuracy, binary_jaccard_index
import torch

preds = torch.tensor([0.2, 0.7, 0.3, 0.3])  # predicted probabilities
target = torch.tensor([1, 1, 0, -1])  # labels, with -1 marking positions to ignore

binary_accuracy(preds, target, ignore_index=-1)  # tensor(0.6667)
binary_jaccard_index(preds, target, ignore_index=-1)  # tensor(0.5000)
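The same holds for the multiclass variants, for example (a sketch assuming the post-refactor multiclass_accuracy API):

from torchmetrics.functional import multiclass_accuracy

preds = torch.tensor([0, 1, 2, 1])
target = torch.tensor([0, 1, -100, 1])  # -100 marks the ignored position
multiclass_accuracy(preds, target, num_classes=3, average="micro", ignore_index=-100)  # tensor(1.)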

This issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.

SkafteNicki avatar Aug 29 '22 13:08 SkafteNicki