torchmetrics
JaccardIndex with multilabel=True not working
🐛 Bug
JaccardIndex is not implemented correctly when multilabel=True and average is different from "none".
To Reproduce
Instantiate JaccardIndex(..., multilabel=True, average="micro") and call it as usual with multilabel classification data.
Code sample
import torch
import torchmetrics

target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)
ji = torchmetrics.classification.JaccardIndex(num_classes=3, multilabel=True, average="micro")
ji(preds, target)
Environment
- TorchMetrics version 0.9.3
Additional context
When multilabel=True, _jaccard_from_confmat is called and then index 1 of its result is accessed (this part is correct; see the shape-mismatch sketch after the list below) https://github.com/Lightning-AI/metrics/blob/v0.9.3/torchmetrics/classification/jaccard.py#L117
- About micro: It returns a scalar (sums are global) https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/torchmetrics/functional/classification/jaccard.py#L85
- About macro: It returns a scalar (mean is global) https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/torchmetrics/functional/classification/jaccard.py#L80
- About weighted: It returns a scalar (sum is global) https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/torchmetrics/functional/classification/jaccard.py#L91
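Put together: whenever average is not "none", the reduced output is a 0-dim tensor, so the subsequent index-1 access cannot work. A minimal sketch of that shape mismatch (plain torch only; the values are made up for illustration and this is not the actual torchmetrics internals):
import torch

# per-class IoU for one 2x2 (binary/multilabel) confusion matrix
per_class_iou = torch.tensor([0.10, 0.40])

# average="none": taking the positive class works as intended
per_class_iou[1]                 # tensor(0.4000)

# average="micro"/"macro"/"weighted": the reduction has already collapsed the
# class dimension to a scalar, so the later index-1 access is meaningless
reduced = per_class_iou.mean()   # tensor(0.2500), 0-dim
reduced[1]                       # IndexError: invalid index of a 0-dim tensor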
Possible implementation snippet
With this base IoU (Jaccard) implementation, you can easily compose the different combinations (multilabel=True + macro, multilabel=False + micro, etc.):
def _compute_iou(cm: torch.Tensor) -> torch.Tensor:
    # cm: confusion matrix of shape (..., K, K); works batched or unbatched
    intersection = cm.diagonal(dim1=-2, dim2=-1)
    union = torch.sum(cm, dim=-1) + torch.sum(cm, dim=-2) - intersection
    return intersection.float() / union.float()
For example:
Micro + multilabel=True
# mlcm: (3, 2, 2)
iou = _compute_iou(mlcm.sum(0))[1]
Macro + multilabel=True
# mlcm: (3, 2, 2)
iou = _compute_iou(mlcm)[:, 1].mean()
Macro + multilabel=False
# cm: (3, 3)
iou = _compute_iou(cm).mean()
And so on.
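For completeness, here is a self-contained sketch of how these snippets could be exercised end-to-end. The confusion matrices are built with sklearn (multilabel_confusion_matrix / confusion_matrix) purely for illustration; that choice is mine and not part of the proposal:
import torch
from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix

def _compute_iou(cm: torch.Tensor) -> torch.Tensor:
    intersection = cm.diagonal(dim1=-2, dim2=-1)
    union = torch.sum(cm, dim=-1) + torch.sum(cm, dim=-2) - intersection
    return intersection.float() / union.float()

# multilabel: mlcm has shape (num_labels, 2, 2)
target = torch.randint(0, 2, (8, 3))
preds = (torch.rand(8, 3) > 0.5).int()
mlcm = torch.from_numpy(multilabel_confusion_matrix(target.numpy(), preds.numpy()))
micro_iou = _compute_iou(mlcm.sum(0))[1]     # micro: sum the per-label CMs, then IoU of the positive class
macro_iou = _compute_iou(mlcm)[:, 1].mean()  # macro: per-label positive-class IoU, then mean

# multiclass: cm has shape (num_classes, num_classes)
mc_target = torch.randint(0, 3, (8,))
mc_preds = torch.randint(0, 3, (8,))
cm = torch.from_numpy(confusion_matrix(mc_target.numpy(), mc_preds.numpy(), labels=[0, 1, 2]))
macro_mc_iou = _compute_iou(cm).mean()       # macro over classes
# note: a class absent from both preds and target yields nan here; the real
# metric handles that case separately (e.g. via an absent_score argument)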
Hi! Thanks for your contribution, great first issue!
This issue will be fixed by the classification refactor: see this issue https://github.com/Lightning-AI/metrics/issues/1001 and this PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes
Small recap: this issue describes that jaccard_index is wrongly calculated in the multilabel setting, simply due to a wrong implementation. The issue has been fixed in the refactor such that everything should be right (our implementation is now better tested against sklearn). The only difference is that instead of jaccard_index, the specialized version multilabel_jaccard_index should be used:
from torchmetrics.functional import multilabel_jaccard_index
import torch
target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)
multilabel_jaccard_index(preds, target, num_labels=3, average="micro") # tensor(0.2632)
multilabel_jaccard_index(preds, target, num_labels=3, average="macro") # tensor(0.2762)
multilabel_jaccard_index(preds, target, num_labels=3, average=None) # tensor([0.1429, 0.2857, 0.4000])
which gives the correct results. The issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.
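For the module-based API, the corresponding specialized class from the refactor can be used the same way (a sketch assuming torchmetrics >= 0.10, where MultilabelJaccardIndex is available):
import torch
from torchmetrics.classification import MultilabelJaccardIndex

target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)

metric = MultilabelJaccardIndex(num_labels=3, average="micro")
metric(preds, target)            # micro-averaged Jaccard index over all labels

per_label = MultilabelJaccardIndex(num_labels=3, average=None)
per_label(preds, target)         # one score per label, shape (3,)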