torchmetrics
Support ignore_index in ConfusionMatrix
🚀 Feature
Add an ignore_index parameter to the classification metric ConfusionMatrix, as in Accuracy, Precision, Recall, etc.
Motivation
I use a MetricCollection for my classifier evaluation, and every other metric in this collection supports ignore_index except ConfusionMatrix. This causes a lot of trouble, and I don't see a reason for not supporting it.
Pitch
ConfusionMatrix should support ignore_index, so that the following works:
import torch
from torchmetrics import ConfusionMatrix
target = torch.tensor([1, 1, 0, 0, -1])
preds = torch.tensor([0, 1, 0, 0, 1])
confmat = ConfusionMatrix(num_classes=2, ignore_index=-1)  # ignore_index is the proposed argument
confmat(preds, target)
and output:
tensor([[2, 0],
        [1, 1]])
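The requested behavior can be sketched in plain PyTorch: drop every position whose target equals ignore_index, then count (target, prediction) pairs. The helper name confusion_matrix_ignore_index is hypothetical and not part of torchmetrics; this is only an illustration of the expected semantics, not the library's implementation.

```python
import torch


def confusion_matrix_ignore_index(preds, target, num_classes, ignore_index=None):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    if ignore_index is not None:
        keep = target != ignore_index  # mask of samples that should be counted
        preds, target = preds[keep], target[keep]
    # Encode each (target, pred) pair as one bin index: target * C + pred
    bins = target.long() * num_classes + preds.long()
    counts = torch.bincount(bins, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


target = torch.tensor([1, 1, 0, 0, -1])
preds = torch.tensor([0, 1, 0, 0, 1])
print(confusion_matrix_ignore_index(preds, target, num_classes=2, ignore_index=-1))
# tensor([[2, 0],
#         [1, 1]])
```

The last sample (target -1) is masked out before counting, which is why the result matches the output requested above.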
Alternatives
I tried implementing my own IgnoreIndexConfusionMatrix, but I ran into device synchronization issues.
Hi! Thanks for your contribution, great first issue!
This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes.
Small recap: this issue asks that ConfusionMatrix also support the ignore_index argument known from Accuracy, Precision, and Recall. This is now supported in all the newly introduced versions, e.g. BinaryConfusionMatrix, MulticlassConfusionMatrix, and MultilabelConfusionMatrix:
import torch
from torchmetrics.classification import BinaryConfusionMatrix
target = torch.tensor([1, 1, 0, 0, -1])
preds = torch.tensor([0, 1, 0, 0, 1])
# binary task: no num_classes argument; positions where target == -1 are skipped
confmat = BinaryConfusionMatrix(ignore_index=-1)
confmat(preds, target)
# tensor([[2, 0],
#         [1, 1]])
This issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.