IndexError: index -100 is out of bounds for dimension 1 with size 2
🐛 Bug
When average="none" (or None) is passed to a classification metric, a negative ignore_index is not ignored. Instead, update() raises the IndexError above. A negative ignore_index works fine when average is set to other values.
To Reproduce
Run the code below:
import torch
from torchmetrics import Precision

target = torch.tensor([1, 1])
preds = torch.tensor([[-0.0027, -0.0023],
                      [-0.0015, -0.0098]])

# Default average: works with the negative ignore_index
precision = Precision(2, ignore_index=-100)
precision.update(preds=preds, target=target)

# average="none": raises IndexError
precision = Precision(2, average="none", ignore_index=-100)
precision.update(preds=preds, target=target)
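Expected behavior
Both update calls should succeed, with samples labelled -100 ignored. Instead, the second call fails with: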
IndexError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/torchmetrics/metric.py in wrapped_func(*args, **kwargs)
    393             with torch.set_grad_enabled(self._enable_grad):
    394                 try:
--> 395                     update(*args, **kwargs)
    396                 except RuntimeError as err:
    397                     if "Expected all tensors to be on" in str(err):

/usr/local/lib/python3.7/dist-packages/torchmetrics/classification/stat_scores.py in update(self, preds, target)
    710             top_k=self.top_k,
    711             multiclass=self.multiclass,
--> 712             ignore_index=self.ignore_index,
    713         )
    714 

/usr/local/lib/python3.7/dist-packages/torchmetrics/functional/classification/stat_scores.py in _stat_scores_update(preds, target, reduce, mdmc_reduce, num_classes, top_k, threshold, multiclass, ignore_index, mode)
    977     # Take care of ignore_index
    978     if ignore_index is not None and reduce == "macro" and not _negative_index_dropped:
--> 979         tp[..., ignore_index] = -1
    980         fp[..., ignore_index] = -1
    981         tn[..., ignore_index] = -1

IndexError: index -100 is out of bounds for dimension 1 with size 2
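The failing line uses ignore_index directly as a position along the class dimension of the per-class statistics. With 2 classes, only positions -2 to 1 are valid, so -100 is out of range even though it is a perfectly good "ignore this label" sentinel. A minimal plain-PyTorch sketch of just that indexing step (the tp tensor here is a hypothetical stand-in, not the library's actual state):

```python
import torch

num_classes = 2
ignore_index = -100

# Hypothetical stand-in for the per-class true-positive counts in the traceback,
# shaped (batch, num_classes) so the class dimension is dimension 1
tp = torch.zeros(1, num_classes, dtype=torch.long)

try:
    # Mirrors `tp[..., ignore_index] = -1`: -100 is not a valid position
    # along a dimension of size 2, so PyTorch raises IndexError
    tp[..., ignore_index] = -1
    error_message = None
except IndexError as err:
    error_message = str(err)

print(error_message)
```

A non-negative ignore_index smaller than num_classes is a valid position, which is why the same line does not fail in that case.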
Environment
torchmetrics==0.10.0, Python 3.7
Hi! Thanks for your contribution, great first issue!
Hi @senzeyu. Thanks for raising this issue. v0.10 brought large changes to the classification package, which you can read more about here: https://devblog.pytorchlightning.ai/torchmetrics-v0-10-large-changes-to-classifications-b162b674e7e1
The bug you describe has been fixed in the new interface, but the fix is not fully implemented in the old interface, which is being deprecated. Therefore, instead of the two examples you provided, I recommend the following:
import torch
from torchmetrics.classification import BinaryPrecision, MulticlassPrecision

target = torch.tensor([1, 1])  # same inputs as in the issue
preds = torch.tensor([[-0.0027, -0.0023],
                      [-0.0015, -0.0098]])

# If you are only interested in the precision in the binary case, use this formulation
precision = BinaryPrecision(ignore_index=-100)
precision.update(preds.argmax(dim=1), target)

# If you are interested in the precision for each of the two classes, use this formulation
precision = MulticlassPrecision(num_classes=2, average=None, ignore_index=-100)
precision.update(preds, target)
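For intuition, a per-class (average=None) precision that tolerates a negative ignore_index can be sketched in plain PyTorch by removing ignored samples before counting, so the sentinel is never used as a class position. The variable name _negative_index_dropped in the traceback suggests the library takes a similar route, but this is not torchmetrics' implementation: per_class_precision is a hypothetical helper, and returning 0 for a class that is never predicted is an assumption of this sketch.

```python
import torch


def per_class_precision(preds: torch.Tensor, target: torch.Tensor,
                        num_classes: int, ignore_index: int) -> torch.Tensor:
    """Hypothetical helper: per-class precision that drops ignored samples first."""
    # Convert (N, C) scores to hard class predictions
    if preds.ndim == 2:
        preds = preds.argmax(dim=1)

    # Drop samples whose target equals ignore_index (negative values included),
    # instead of using ignore_index as an index into per-class statistics
    keep = target != ignore_index
    preds, target = preds[keep], target[keep]

    # Number of times each class was predicted
    predicted = torch.bincount(preds, minlength=num_classes).float()
    # Number of correct predictions per class (true positives)
    tp = torch.bincount(preds[preds == target], minlength=num_classes).float()

    # Assumption: a class that is never predicted gets precision 0
    return tp / predicted.clamp(min=1)


target = torch.tensor([1, 1])
preds = torch.tensor([[-0.0027, -0.0023],
                      [-0.0015, -0.0098]])
result = per_class_precision(preds, target, num_classes=2, ignore_index=-100)
print(result)  # tensor([0., 1.])
```

With the issue's inputs, the argmax predictions are [1, 0] against targets [1, 1], so class 0 has one false positive and class 1 has one true positive.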
Closing issue.