
JaccardIndex with multilabel=True not working

Open · claverru opened this issue 3 years ago · 1 comment

🐛 Bug

JaccardIndex is implemented incorrectly when multilabel=True and average is anything other than "none".

To Reproduce

Instantiate JaccardIndex(..., multilabel=True, average="micro") and call it as usual with multilabel classification data.

Code sample

import torch
import torchmetrics

target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)

ji = torchmetrics.classification.JaccardIndex(num_classes=3, multilabel=True, average="micro")

ji(preds, target)

Environment

  • TorchMetrics version 0.9.3

Additional context

When multilabel=True, _jaccard_from_confmat is called and then index 1 is accessed (this part is correct): https://github.com/Lightning-AI/metrics/blob/v0.9.3/torchmetrics/classification/jaccard.py#L117

  • About micro: it returns a scalar because the sums are taken globally over the stacked confusion matrices instead of per label https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/torchmetrics/functional/classification/jaccard.py#L85
  • About macro: it returns a scalar because the mean is taken globally https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/torchmetrics/functional/classification/jaccard.py#L80
  • About weighted: it returns a scalar because the sum is taken globally https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/torchmetrics/functional/classification/jaccard.py#L91
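For reference, what the multilabel reductions should compute can be written out with plain torch (the tensors below are illustrative, not from the issue): each label has its own binary confusion matrix, and "micro" pools the positive-class intersection and union across labels rather than reducing over a flattened (C, C) matrix.

```python
import torch

# Illustrative multilabel hard predictions/targets (not from the issue).
target = torch.tensor([[1, 0, 1],
                       [0, 1, 1],
                       [1, 1, 0]])
preds = torch.tensor([[1, 0, 0],
                      [0, 1, 1],
                      [0, 1, 0]])

# Per-label intersection (TP) and union (TP + FP + FN) of the positive class.
intersection = (preds & target).sum(dim=0)
union = (preds | target).sum(dim=0)

per_label_iou = intersection.float() / union.float()          # "none": one IoU per label
micro_iou = intersection.sum().float() / union.sum().float()  # "micro": pool counts first
macro_iou = per_label_iou.mean()                              # "macro": mean of per-label IoUs
```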

Possible implementation snippet

With this base IoU (Jaccard) implementation, the different combinations (multilabel=True + macro, multilabel=False + micro, etc.) can be organized easily:

import torch

def _compute_iou(cm: torch.Tensor) -> torch.Tensor:
    # Works on a single (C, C) confusion matrix or a stacked (N, 2, 2) batch.
    intersection = cm.diagonal(dim1=-2, dim2=-1)
    union = torch.sum(cm, dim=-1) + torch.sum(cm, dim=-2) - intersection
    return intersection.float() / union.float()

For example:

Micro + multilabel=True

# mlcm: (3, 2, 2)
iou = _compute_iou(mlcm.sum(0))[1]

Macro + multilabel=True

# mlcm: (3, 2, 2)
iou = _compute_iou(mlcm)[:, 1].mean()

Macro + multilabel=False

# cm: (3, 3)
iou = _compute_iou(cm).mean()

And so on.
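The combinations above can be exercised end-to-end. A minimal runnable sketch, assuming the (num_labels, 2, 2) multilabel confusion-matrix layout described in the issue (the example tensors are illustrative, not from the issue):

```python
import torch

def _compute_iou(cm: torch.Tensor) -> torch.Tensor:
    intersection = cm.diagonal(dim1=-2, dim2=-1)
    union = torch.sum(cm, dim=-1) + torch.sum(cm, dim=-2) - intersection
    return intersection.float() / union.float()

# Illustrative multilabel hard predictions/targets.
target = torch.tensor([[1, 0, 1], [0, 1, 1], [1, 1, 0]])
preds = torch.tensor([[1, 0, 0], [0, 1, 1], [0, 1, 0]])

# Build the (num_labels, 2, 2) stack of binary confusion matrices,
# where mlcm[i, t, p] counts samples with target t and prediction p for label i.
mlcm = torch.stack([
    torch.bincount(target[:, i] * 2 + preds[:, i], minlength=4).reshape(2, 2)
    for i in range(target.shape[1])
])

micro = _compute_iou(mlcm.sum(0))[1]     # pool counts across labels, then IoU of class 1
macro = _compute_iou(mlcm)[:, 1].mean()  # IoU of class 1 per label, then mean
```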

claverru avatar Aug 04 '22 11:08 claverru

Hi! thanks for your contribution!, great first issue!

github-actions[bot] avatar Aug 04 '22 11:08 github-actions[bot]

The issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes.

Small recap: this issue describes that jaccard_index is wrongly calculated in the multilabel setting, simply due to an incorrect implementation. The issue has been fixed in the refactor, so everything should now be correct (our implementation is better tested against sklearn now). The only difference is that instead of jaccard_index, the specialized version multilabel_jaccard_index should be used:

from torchmetrics.functional import multilabel_jaccard_index
import torch

target = torch.randint(0, 2, (8, 3))
preds = torch.rand(8, 3)

multilabel_jaccard_index(preds, target, num_labels=3, average="micro") # tensor(0.2632)
multilabel_jaccard_index(preds, target, num_labels=3, average="macro") # tensor(0.2762)
multilabel_jaccard_index(preds, target, num_labels=3, average=None)  # tensor([0.1429, 0.2857, 0.4000])

which gives the correct results. The issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.

SkafteNicki avatar Aug 28 '22 12:08 SkafteNicki