
`top_k` for `MulticlassRecall` is not working as expected

Open c23996 opened this issue 10 months ago • 2 comments

🐛 Bug

`MulticlassRecall` with `top_k > 1` and `average="macro"` does not behave as expected. Recall@k should never decrease as `top_k` increases, because the top-(k+1) predictions always contain the top-k predictions. In some cases, however, it decreases.

To Reproduce

```python
import torch
from torchmetrics.classification import MulticlassRecall

num_classes = 200
# 5 random samples: softmax scores over 200 classes and random integer targets
preds = torch.randn(5, num_classes).softmax(dim=-1)
target = torch.randint(num_classes, (5,))

recall_val_top1 = MulticlassRecall(num_classes=num_classes, top_k=1, average="macro")
recall_val_top5 = MulticlassRecall(num_classes=num_classes, top_k=5, average="macro")
recall_val_top10 = MulticlassRecall(num_classes=num_classes, top_k=10, average="macro")
recall_val_top100 = MulticlassRecall(num_classes=num_classes, top_k=100, average="macro")

print(
    recall_val_top1(preds, target),
    recall_val_top5(preds, target),
    recall_val_top10(preds, target),
    recall_val_top100(preds, target),
)
```

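It prints `tensor(0.) tensor(0.0357) tensor(0.0213) tensor(0.0154)`, so recall rises from top-1 to top-5 and then falls for top-10 and top-100. (The inputs are random, so the exact values vary between runs.)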

Expected behavior

The recall should not decrease as `k` grows.

Environment

  • TorchMetrics version (and how you installed it): 1.2.1, via pip
  • Python & PyTorch version: 3.10.0 & 1.12.1
  • Any other relevant information such as OS (e.g., Linux): OS

Additional context

I looked at `_adjust_weights_safe_divide`, which `_precision_recall_reduce` uses to compute the averaged recall, and I am unsure about this snippet:

    weights = torch.ones_like(score)
    if not multilabel:
        weights[tp + fp + fn == 0] = 0.0
    return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
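
To make the effect of this snippet concrete, here is a pure-Python mirror of it. This is a sketch, not the library code: `safe_div` and `macro_recall_current` are hypothetical names, the lists stand in for the per-class `tp`/`fp`/`fn` tensors, and it assumes the per-class score is recall `tp / (tp + fn)` with 0/0 treated as 0. One class is fully recalled, and classes that only received false positives are added to it:

```python
# Pure-Python mirror of the quoted snippet for the multiclass, average="macro" case.
# tp, fp, fn are per-class count lists standing in for the tensors.

def safe_div(num, den):
    # 0 / 0 is treated as 0, in the spirit of _safe_divide
    return num / den if den != 0 else 0.0


def macro_recall_current(tp, fp, fn):
    # per-class recall tp / (tp + fn)
    score = [safe_div(t, t + n) for t, n in zip(tp, fn)]
    # weights = ones, then zero out classes with tp + fp + fn == 0
    weights = [0.0 if t + p + n == 0 else 1.0 for t, p, n in zip(tp, fp, fn)]
    # weighted mean of the per-class scores
    return safe_div(sum(w * s for w, s in zip(weights, score)), sum(weights))


# One fully recalled class plus `extra` classes that only received false positives
for extra in (0, 1, 4, 27):
    tp = [1] + [0] * extra
    fp = [0] + [1] * extra
    fn = [0] * (1 + extra)
    print(extra, round(macro_recall_current(tp, fp, fn), 4))
# prints: 0 1.0 / 1 0.5 / 4 0.2 / 27 0.0357 (one pair per line)
```

Each false-positive-only class has recall 0 but weight 1, so the average falls as 1 / (1 + extra). With 27 such classes it is 1/28 ≈ 0.0357, the same value as the top-5 result above, which is consistent with (though does not prove) this explanation.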

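In the multiclass case with top-k, every sample contributes k predicted classes, so the number of false positives (`fp`) grows with `k`. More classes then satisfy `tp + fp + fn > 0` and keep a weight of 1, which increases `weights.sum(-1, keepdim=True)` and lowers the final recall@k, even though each class's recall `tp / (tp + fn)` can only stay the same or rise. I am also wondering: when computing macro-averaged recall, should the mask be `weights[tp + fn == 0] = 0.0` instead, so that only classes present in the target are averaged?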
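
The question about the mask can be checked on a tiny hand-made example. The sketch below is pure Python; `topk_indices`, `per_class_counts` and `macro_recall` are hypothetical helpers, the data is invented for illustration, and it assumes that each of the k top-scoring classes counts as a positive prediction for its sample. It compares the current mask (`tp + fp + fn == 0`) with the proposed one (`tp + fn == 0`) for k = 1..6:

```python
# Hand-made example: 2 samples, 6 classes (hypothetical data, not from the report).
NUM_CLASSES = 6
preds = [
    [0.60, 0.20, 0.10, 0.05, 0.03, 0.02],  # target 0 is ranked 1st
    [0.05, 0.40, 0.30, 0.15, 0.08, 0.02],  # target 5 is ranked last
]
target = [0, 5]


def topk_indices(row, k):
    # indices of the k largest scores (sorted() is stable, so ties keep index order)
    return sorted(range(len(row)), key=lambda c: -row[c])[:k]


def per_class_counts(preds, target, num_classes, k):
    # assumption: every class in a sample's top-k counts as a positive prediction
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for row, y in zip(preds, target):
        top = topk_indices(row, k)
        for c in top:
            if c == y:
                tp[c] += 1
            else:
                fp[c] += 1
        if y not in top:
            fn[y] += 1
    return tp, fp, fn


def macro_recall(tp, fp, fn, only_target_classes):
    # per-class recall tp / (tp + fn), with 0 / 0 treated as 0
    score = [t / (t + n) if t + n else 0.0 for t, n in zip(tp, fn)]
    if only_target_classes:
        # proposed mask: drop classes with tp + fn == 0 (absent from the target)
        keep = [t + n > 0 for t, n in zip(tp, fn)]
    else:
        # mask from the quoted snippet: drop classes with tp + fp + fn == 0
        keep = [t + p + n > 0 for t, p, n in zip(tp, fp, fn)]
    kept = sum(keep)
    return sum(s for s, w in zip(score, keep) if w) / kept if kept else 0.0


for only_target in (False, True):
    values = [
        macro_recall(*per_class_counts(preds, target, NUM_CLASSES, k), only_target_classes=only_target)
        for k in range(1, NUM_CLASSES + 1)
    ]
    print(only_target, [round(v, 4) for v in values])
# False [0.3333, 0.25, 0.2, 0.1667, 0.1667, 0.3333]
# True [0.5, 0.5, 0.5, 0.5, 0.5, 1.0]
```

With the current mask the value drops from 0.3333 to 0.1667 as k grows, because classes that only collect false positives keep entering the average. With the proposed mask it never decreases: the set of averaged classes is fixed by the target, and each of those classes' recall can only rise with k.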

c23996 · Oct 11 '23 16:10

Hi! Thanks for your contribution, great first issue!

github-actions[bot] · Oct 11 '23 16:10

`top_k` for `MulticlassAccuracy` doesn't work as expected either.
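
For comparison, the sample-level notion of top-k accuracy (the fraction of samples whose target is among the k highest scores) is non-decreasing in k by construction, because the top-(k+1) set contains the top-k set. Below is a minimal pure-Python sketch with hypothetical data, useful as a sanity check; `topk_accuracy` is a hypothetical helper, and this is not a claim about which reduction `MulticlassAccuracy` applies internally:

```python
def topk_accuracy(preds, target, k):
    # fraction of samples whose target class is among the k highest scores
    hits = 0
    for row, y in zip(preds, target):
        top = sorted(range(len(row)), key=lambda c: -row[c])[:k]
        hits += y in top  # bool counts as 0 or 1
    return hits / len(target)


# Hypothetical data: 3 samples, 3 classes
preds = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.5, 0.4, 0.1]]
target = [0, 1, 1]
print([round(topk_accuracy(preds, target, k), 4) for k in (1, 2, 3)])  # [0.3333, 1.0, 1.0]
```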

iuhgnor · Nov 18 '23 07:11