
`top_k` for `MulticlassF1Score` is not working correctly

Open eneserdo opened this issue 2 years ago • 4 comments

🐛 Bug

The top_k argument of MulticlassF1Score is not working as expected. It is supposed to give higher results as top_k increases, but sometimes that does not happen.

According to docs:

top_k (int) – Number of highest probability or logit score predictions considered to find the correct label.

So the score should always increase as top_k increases.

Also normally, when top_k=num_classes, it is expected to give 1 (100%), but that's not happening either.
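
To make the expectation above concrete, here is a minimal plain-Python sketch of how I read the docs: a sample counts as a hit when its true label is among the k highest scores. The helper names and the toy scores are made up for illustration, and this is not the torchmetrics implementation. Under this reading the hit rate can never decrease as k grows (it may stay flat), and it reaches 1 once k equals the number of classes.

```python
# Plain-Python sketch of the top-k reading described above.
# Helper names are hypothetical; this is NOT how torchmetrics computes anything.

def top_k_labels(scores, k):
    """Return the indices of the k highest scores, best first."""
    order = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
    return order[:k]


def top_k_hit_rate(preds, target, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = sum(
        1 for scores, label in zip(preds, target)
        if label in top_k_labels(scores, k)
    )
    return hits / len(target)


# Three samples, three classes (made-up scores for illustration).
preds = [
    [0.7, 0.2, 0.1],  # ranking: 0, 1, 2
    [0.1, 0.3, 0.6],  # ranking: 2, 1, 0
    [0.5, 0.4, 0.1],  # ranking: 0, 1, 2
]
target = [0, 1, 2]

# Hit rate for k = 1, 2, 3: one more sample becomes a hit at each step.
rates = [top_k_hit_rate(preds, target, k) for k in (1, 2, 3)]
print(rates)
```

With k equal to the number of classes every label is in the top k by construction, which is why I expected 1 in that case.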

To Reproduce

Steps to reproduce the behavior...

Code sample
import torch
from torchmetrics.classification import MulticlassF1Score

# 200 samples, 5 classes: random probabilities and random integer labels
preds = torch.randn(200, 5).softmax(dim=-1)
target = torch.randint(5, (200,))

# Same metric, three different top_k values
f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")
f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
f1_val_top5 = MulticlassF1Score(num_classes=5, top_k=5, average="macro")

print(f1_val_top1(preds, target), f1_val_top3(preds, target), f1_val_top5(preds, target))

It returns (tensor(0.1774), tensor(0.2740), tensor(0.3318))

As far as I understand from the documentation, when I set top_k=5 it must give 1, because there are only 5 classes anyway.

More explicitly, I expected the following two to have the same output:

import torch, functorch
from torchmetrics.classification import MulticlassF1Score
preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))

f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")

# Class indices of the 3 highest scores per sample, best first
pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
pred_top_1 = pred_top_3[:, 0]

# Replace an incorrect top-1 label with the correct one, but only if the correct label is among the top-3 predictions
pred_corrected_top3 = torch.where(functorch.vmap(lambda t1, t2: torch.isin(t1, t2))(target, pred_top_3), target, pred_top_1)

print(f1_val_top3(preds, target), f1_val_top1(pred_corrected_top3, target))

But the results are different.
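
For reference, the "corrected labels" computation can also be written without torch or functorch, which makes it easy to check by hand. This is only a sketch of the behavior I expected, not of what torchmetrics does: the helper names and toy scores are hypothetical, and scoring a class with no true positives, false positives or false negatives as 0 is my own assumption.

```python
# Plain-Python sketch of the expected behavior (hypothetical helpers,
# NOT the torchmetrics implementation).

def corrected_labels(preds, target, k):
    """Use the true label when it is among the k highest scores, else the top-1 label."""
    out = []
    for scores, label in zip(preds, target):
        order = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        out.append(label if label in order[:k] else order[0])
    return out


def macro_f1(pred_labels, target, num_classes):
    """Unweighted mean of per-class F1.

    Assumption: a class with no TP, FP or FN contributes 0.
    """
    per_class = []
    for c in range(num_classes):
        pairs = list(zip(pred_labels, target))
        tp = sum(1 for p, t in pairs if p == c and t == c)
        fp = sum(1 for p, t in pairs if p == c and t != c)
        fn = sum(1 for p, t in pairs if p != c and t == c)
        denom = 2 * tp + fp + fn
        per_class.append(2 * tp / denom if denom else 0.0)
    return sum(per_class) / num_classes


# Three samples, three classes (made-up scores for illustration).
preds = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
]
target = [0, 1, 2]

# Macro F1 of the corrected labels for k = 1, 2, 3.
f1_by_k = [macro_f1(corrected_labels(preds, target, k), target, 3) for k in (1, 2, 3)]
print(f1_by_k)
```

On this toy data the corrected labels equal the targets once k is 3, so the macro F1 is 1, which is the behavior I expected from top_k=num_classes.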

Environment

  • TorchMetrics 0.11.3 (installed via pip)
  • Python 3.8.16
  • PyTorch 1.12.0

Mar 25 '23 11:03 eneserdo

Hi! Thanks for your contribution! Great first issue!

Mar 25 '23 11:03 github-actions[bot]

Hi, I am interested in solving this issue. Can I work on it?

May 25 '23 10:05 arijitde92

@arijitde92 sorry for the late reply. Sure, you are welcome to take it 💜

Aug 25 '23 11:08 Borda