torchmetrics
Support -100 as `ignore_index` in Accuracy
When creating `Accuracy(ignore_index=-100)`, it raises an error saying that -100 is not in the classes.
But `ignore_index=-100` is the default of `torch.nn.CrossEntropyLoss(...)`, and we often train masked language models with a label tensor of shape (batch_size, sequence_length), where positions not selected by the masking process are filled with -100 so that they do not contribute to the loss.
We can still check that `ignore_index` is not greater than the number of classes, but lifting the restriction that `ignore_index` cannot be negative would be better.
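A small sketch of the PyTorch behavior referred to above, using `torch.nn.functional.cross_entropy`, whose `ignore_index` defaults to -100: with the default mean reduction, positions labeled -100 are excluded from both the sum and the count, so the loss equals the loss computed on the non-ignored positions alone. The shapes and values here are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes = 8
logits = torch.randn(6, num_classes)
# -100 marks positions that should not contribute to the loss
labels = torch.tensor([3, -100, 5, -100, -100, 1])

# ignore_index defaults to -100, so no extra argument is needed
loss_default = F.cross_entropy(logits, labels)

# Equivalent: drop the ignored positions and average over the rest
keep = labels != -100
loss_manual = F.cross_entropy(logits[keep], labels[keep])

print(loss_default.item(), loss_manual.item())
```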
Hi! Thanks for your contribution, great first issue!
Hi @richarddwang, sounds good to me. Would you be up for sending a PR?
The relevant check that should be changed: https://github.com/PyTorchLightning/metrics/blob/7af6d13b3c2186aacf5594793317c15a199608bb/torchmetrics/functional/classification/stat_scores.py#L95-L96 (and similar checks)
Hi @SkafteNicki, I am not clear about how `ignore_index` is handled:
https://github.com/PyTorchLightning/metrics/blob/7af6d13b3c2186aacf5594793317c15a199608bb/torchmetrics/functional/classification/stat_scores.py#L111-L113
https://github.com/PyTorchLightning/metrics/blob/7af6d13b3c2186aacf5594793317c15a199608bb/torchmetrics/functional/classification/stat_scores.py#L24-L27
Would it still work when `ignore_index` is -100?
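The linked lines are not reproduced here, so this is only an assumption about why a negative value is awkward: if the implementation one-hot encodes the target, `torch.nn.functional.one_hot` rejects negative class values. One possible workaround, sketched with a hypothetical helper `one_hot_with_ignore`, is to remap the ignored value to a temporary extra class `num_classes`, encode with `num_classes + 1` columns, and then drop that extra column so ignored positions become all-zero rows.

```python
import torch
import torch.nn.functional as F


def one_hot_with_ignore(target: torch.Tensor, num_classes: int, ignore_index: int) -> torch.Tensor:
    """Hypothetical helper: one-hot encode `target`, giving ignored positions an all-zero row."""
    # Route ignored positions to a temporary extra class so one_hot never sees e.g. -100 or 255
    remapped = torch.where(target == ignore_index, torch.full_like(target, num_classes), target)
    encoded = F.one_hot(remapped, num_classes=num_classes + 1)
    # Drop the temporary column; ignored rows are now all zeros
    return encoded[..., :num_classes]


target = torch.tensor([2, -100, 0, 1])
print(one_hot_with_ignore(target, num_classes=3, ignore_index=-100))
```

Because the extra column is removed before anything downstream sees it, the ignored value never has to lie inside `[0, num_classes)`.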
@richarddwang yeah, maybe it is a bit more complex than first thought. Could you please provide an example input that you would normally feed into `torch.nn.CrossEntropyLoss` with `ignore_index=-100`, just so I can better understand?
Sure. Anyone training a masked language model will often see something like this:
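```python
import torch
import torch.nn as nn

batch_size, sequence_length, vocab_size = 4, 5, 10000
mlm_logits = torch.randn(batch_size, sequence_length, vocab_size)
labels = torch.tensor([
    [1234, -100, -100, -100, -100],
    [-100, -100, -100, 7567, -100],
    [-100, -100, 8900, -100, -100],
    [-100, -100, -100, -100, 3456],
])  # (4, 5), where -100 marks positions the MLM should not learn from
# CrossEntropyLoss uses ignore_index=-100 by default
loss = nn.CrossEntropyLoss()(mlm_logits.view(-1, vocab_size), labels.view(-1))
mlm_predictions = mlm_logits.argmax(dim=-1)  # (4, 5)
# Inside a LightningModule step, with mlm_acc = Accuracy(ignore_index=-100):
self.log('mlm_acc', mlm_acc(mlm_predictions.view(-1), labels.view(-1)))
```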
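For reference, a plain-PyTorch sketch of what `mlm_acc` would be expected to return in the example above: accuracy over only those positions whose label is not -100. `masked_accuracy` is a hypothetical helper, not a TorchMetrics API, and the predictions are hard-coded so the result can be checked by hand.

```python
import torch


def masked_accuracy(preds: torch.Tensor, target: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """Hypothetical helper: fraction of non-ignored positions where preds == target."""
    keep = target != ignore_index
    # Boolean indexing keeps only the non-ignored positions (flattened)
    return (preds[keep] == target[keep]).float().mean()


labels = torch.tensor([
    [1234, -100, -100, -100, -100],
    [-100, -100, -100, 7567, -100],
    [-100, -100, 8900, -100, -100],
    [-100, -100, -100, -100, 3456],
])
# Predictions at ignored positions are arbitrary; only 4 positions count
preds = torch.tensor([
    [1234, 5, 5, 5, 5],
    [5, 5, 5, 7567, 5],
    [5, 5, 1, 5, 5],
    [5, 5, 5, 5, 3456],
])
print(masked_accuracy(preds.view(-1), labels.view(-1)))  # 3 of 4 correct
```

Note that if every position is ignored, `.mean()` over an empty tensor returns nan, so a real metric would need to decide how to handle that case.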
Can we also handle the case of segmentation with datasets like PASCAL VOC, which has a "void" label of 255 but only 21 classes? I tried passing in the number of classes, but then I got the error:
The `ignore_index` 255 is not valid for inputs with 21 classes
It looks like there's a baked-in assumption in `StatScores` and `torchmetrics.utilities.checks` that the targets only contain valid class indices. But segmentation datasets often have a "void" label to mark things like boundaries where the label is ambiguous. So the model outputs a prediction tensor of shape `N, C, H, W` for `C` classes, and the target tensor has shape `N, H, W` with values ranging from 0 to `C-1`, plus the void label. It doesn't appear that TorchMetrics currently supports this, but I think it should, since PyTorch classes like `torch.nn.NLLLoss` don't have an issue with this setup.
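A minimal sketch of the setup described above, assuming a void label of 255 as in PASCAL VOC: void pixels are dropped before a confusion matrix is accumulated with `torch.bincount`, and pixel accuracy and per-class IoU follow from it. `void_aware_confusion` is a hypothetical name, and the tiny 3-class, 2x2 input is made up so the numbers can be checked by hand.

```python
import torch


def void_aware_confusion(logits: torch.Tensor, target: torch.Tensor, num_classes: int, void: int = 255) -> torch.Tensor:
    """Hypothetical helper: (C, C) confusion matrix (rows = target, cols = pred), ignoring void pixels."""
    preds = logits.argmax(dim=1)  # (N, C, H, W) -> (N, H, W)
    keep = target != void
    t, p = target[keep], preds[keep]
    # Encode each (target, pred) pair as a single index, then count occurrences
    counts = torch.bincount(t * num_classes + p, minlength=num_classes * num_classes)
    return counts.view(num_classes, num_classes)


num_classes = 3
# One 2x2 image; logits chosen so the argmax is easy to read off
logits = torch.zeros(1, num_classes, 2, 2)
logits[0, 0, 0, 0] = 1.0  # pixel (0, 0) -> class 0
logits[0, 1, 0, 1] = 1.0  # pixel (0, 1) -> class 1
logits[0, 2, 1, 0] = 1.0  # pixel (1, 0) -> class 2
logits[0, 1, 1, 1] = 1.0  # pixel (1, 1) -> class 1
target = torch.tensor([[[0, 2], [2, 255]]])  # last pixel is void

cm = void_aware_confusion(logits, target, num_classes)
pixel_acc = cm.diag().sum() / cm.sum()
iou = cm.diag() / (cm.sum(0) + cm.sum(1) - cm.diag())
print(cm, pixel_acc, iou)
```

Masking before encoding means the void value never needs to be a valid class index, which mirrors how `ignore_index` works in `torch.nn.NLLLoss`.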
Hi all. I wrote a naive solution that forces the ignored positions to compare as incorrect and then divides by the number of non-ignored positions. There may be a better way; this is just for your reference.
```python
import torch


@torch.no_grad()
def accuracy(
    logits_or_ids,  # <float>(B,L,C) / <long>(B,L)
    labels,  # <long>(B,L)
    ignore_id: int,
):
    if len(logits_or_ids.shape) == len(labels.shape):
        preds = logits_or_ids  # <long>(B,L)
    else:
        preds = logits_or_ids.argmax(dim=-1)  # <long>(B,L)
    ignore = labels == ignore_id
    # make ignored positions incorrect so they are excluded from the correct count
    assert ignore_id != -9999
    preds = preds.masked_fill(ignore, -9999)
    num_correct = (preds == labels).sum()
    num_nonignore = (~ignore).sum()
    return num_correct / num_nonignore
```
@SkafteNicki how are we doing here? :rabbit:
This feature is also a must-have for F1, AUROC, and any other classification metric.
I wrote a simple wrapper that simply drops the targets equal to `ignore_index`:
```python
from typing import Optional

from torch import Tensor
from torchmetrics import functional as FM


def ignore_index(func):
    def ignore_and_call(preds: Tensor, target: Tensor, ignore_target: Optional[int] = None, *args, **kwargs):
        if ignore_target is not None:
            # keep only the positions whose target differs from ignore_target
            keep = target != ignore_target
            preds, target = preds[keep], target[keep]
        return func(*args, **kwargs, preds=preds, target=target)

    return ignore_and_call


accuracy = ignore_index(FM.accuracy)
```
Unfortunately, done like this, you lose the docstring etc. To prevent this, you could, e.g., copy the accuracy method's declaration and then call the wrapped method from within.
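An alternative to copying the declaration, sketched here under the assumption that preserving `__name__` and `__doc__` is what matters: the standard-library `functools.wraps` decorator copies those attributes onto the wrapper. The toy metric `exact_match` is a hypothetical stand-in for `FM.accuracy` so the sketch runs without TorchMetrics; the wrapper signature follows the one above.

```python
import functools
from typing import Optional

import torch
from torch import Tensor


def ignore_index(func):
    @functools.wraps(func)  # copies __name__, __doc__, etc. from func onto the wrapper
    def ignore_and_call(preds: Tensor, target: Tensor, ignore_target: Optional[int] = None, *args, **kwargs):
        if ignore_target is not None:
            # keep only the positions whose target differs from ignore_target
            keep = target != ignore_target
            preds, target = preds[keep], target[keep]
        return func(*args, **kwargs, preds=preds, target=target)

    return ignore_and_call


@ignore_index
def exact_match(preds: Tensor, target: Tensor) -> Tensor:
    """Toy metric: fraction of positions where preds equals target."""
    return (preds == target).float().mean()


print(exact_match.__doc__)
print(exact_match(torch.tensor([1, 2, 0]), torch.tensor([1, 0, -1]), ignore_target=-1))
```

One caveat: `functools.wraps` also sets `__wrapped__`, so `inspect.signature` reports the original function's signature, which does not include `ignore_target`.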
This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes.
Small recap: this issue asks for support for an `ignore_index` outside the [0, num_classes) range in accuracy metrics. After the refactor, all new classification metrics support this. Here are accuracy and Jaccard index as examples:
```python
from torchmetrics.functional import binary_accuracy, binary_jaccard_index
import torch

preds_prob = torch.tensor([0.2, 0.7, 0.3, 0.3])
target_prob = torch.tensor([1, 1, 0, -1])
binary_accuracy(preds_prob, target_prob, ignore_index=-1)  # tensor(0.6667)
binary_jaccard_index(preds_prob, target_prob, ignore_index=-1)  # tensor(0.5000)
```
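To see where 0.6667 and 0.5000 come from, the same numbers can be reproduced by hand in plain PyTorch, assuming the binary metrics' default threshold of 0.5: drop the position whose target equals `ignore_index`, threshold the probabilities, then count. This is an illustration of the arithmetic, not the library's internal code path.

```python
import torch

preds_prob = torch.tensor([0.2, 0.7, 0.3, 0.3])
target = torch.tensor([1, 1, 0, -1])
ignore_index = -1

keep = target != ignore_index            # last position is ignored
preds = (preds_prob[keep] > 0.5).long()  # tensor([0, 1, 0])
t = target[keep]                         # tensor([1, 1, 0])

accuracy = (preds == t).float().mean()   # 2 of 3 correct
tp = ((preds == 1) & (t == 1)).sum()
fp = ((preds == 1) & (t == 0)).sum()
fn = ((preds == 0) & (t == 1)).sum()
jaccard = tp / (tp + fp + fn)            # 1 / (1 + 0 + 1)
print(accuracy, jaccard)
```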
The issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.