`top_k` for `MulticlassF1Score` is not working correctly
🐛 Bug
The `top_k` argument of `MulticlassF1Score` is not working as expected. The score is supposed to get higher as `top_k` increases, but sometimes it does not.
According to the docs:
top_k (int) – Number of highest probability or logit score predictions considered to find the correct label.
So the score should always increase (or at least never decrease) as `top_k` grows.
Also, when `top_k=num_classes`, I would expect it to give 1 (100%), but that does not happen either.
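The reading of the docs described above ("considered to find the correct label") is essentially top-k accuracy: a sample counts as correct if its label is anywhere among the k highest scores. The plain-Python sketch below illustrates that reading only; the helper `topk_hit_rate` is hypothetical and not part of TorchMetrics. It shows the two properties expected above: the rate never decreases as k grows, and it reaches 1.0 at k = num_classes.

```python
import random


def topk_hit_rate(probs, target, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(probs, target):
        # Class indices of the k largest scores in this row.
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += label in top_k
    return hits / len(target)


random.seed(0)
num_classes = 5
probs = [[random.random() for _ in range(num_classes)] for _ in range(200)]
target = [random.randrange(num_classes) for _ in range(200)]

rates = [topk_hit_rate(probs, target, k) for k in range(1, num_classes + 1)]
print(rates)  # non-decreasing in k, and the last entry is 1.0
```

Because every top-k set is a prefix of the same ranking, the sets are nested, which is why the rate cannot drop as k increases.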
To Reproduce
Steps to reproduce the behavior...
Code sample
```python
import torch
from torchmetrics.classification import MulticlassF1Score

preds = torch.randn(200, 5).softmax(dim=-1)
target = torch.randint(5, (200,))
f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")
f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
f1_val_top5 = MulticlassF1Score(num_classes=5, top_k=5, average="macro")
print(f1_val_top1(preds, target), f1_val_top3(preds, target), f1_val_top5(preds, target))
```
In one run (the inputs are random) it returned (tensor(0.1774), tensor(0.2740), tensor(0.3318)).
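One possible explanation for the value at `top_k=5`, offered as an assumption rather than something the docs state: if every class in a sample's top-k set is treated as a predicted positive, then `top_k=5` predicts all five classes for every sample. For a class $c$ that is the target of $n_c$ out of $N$ samples, recall is then 1 and precision is $n_c / N$, so

```math
\mathrm{F1}_c = \frac{2 \cdot \frac{n_c}{N} \cdot 1}{\frac{n_c}{N} + 1} = \frac{2 n_c}{N + n_c}
```

With roughly balanced classes, $n_c \approx N/5$, which gives $\frac{2N/5}{6N/5} = \frac{1}{3}$, close to the 0.3318 reported above.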
As far as I understand the documentation, setting `top_k=5` should give 1, because there are only 5 classes anyway.
More explicitly, I expected the following two to give the same output:
```python
import torch
from torchmetrics.classification import MulticlassF1Score

preds = torch.randn(10, 5).softmax(dim=-1)
target = torch.randint(5, (10,))
f1_val_top3 = MulticlassF1Score(num_classes=5, top_k=3, average="macro")
f1_val_top1 = MulticlassF1Score(num_classes=5, top_k=1, average="macro")

pred_top_3 = torch.argsort(preds, dim=1, descending=True)[:, :3]
pred_top_1 = pred_top_3[:, 0]
# Replace the top-1 prediction with the true label whenever the true label is among the top-3 predictions
in_top_3 = (pred_top_3 == target.unsqueeze(1)).any(dim=1)
pred_corrected_top3 = torch.where(in_top_3, target, pred_top_1)
print(f1_val_top3(preds, target), f1_val_top1(pred_corrected_top3, target))
```
But the results are different.
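To make the difference between the two computations in the snippet above concrete, here is a plain-Python sketch that computes macro F1 under two counting schemes. The first, `f1_topk_as_positives`, is an assumption about what `top_k > 1` does (every class in the top-k set counts as predicted, so the true class gets a TP when present and each other top-k class gets an FP); it is a reconstruction for illustration, not the TorchMetrics implementation. The second, `f1_corrected_top1`, mirrors the label correction in the snippet. All function names are hypothetical.

```python
import random


def macro_f1(tp, fp, fn):
    """Macro-averaged F1 from per-class counts; a class with no TP, FP or FN scores 0."""
    scores = []
    for t, p, n in zip(tp, fp, fn):
        denom = 2 * t + p + n
        scores.append(2 * t / denom if denom else 0.0)
    return sum(scores) / len(scores)


def ranked_classes(row):
    """Class indices sorted from highest to lowest score."""
    return sorted(range(len(row)), key=lambda c: row[c], reverse=True)


def f1_topk_as_positives(probs, target, k):
    """Assumed scheme: every class in a sample's top-k set counts as a positive prediction."""
    num_classes = len(probs[0])
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for row, label in zip(probs, target):
        predicted = set(ranked_classes(row)[:k])
        for c in predicted:
            if c == label:
                tp[c] += 1  # the true class is in the top-k
            else:
                fp[c] += 1  # every other top-k class is a false positive
        if label not in predicted:
            fn[label] += 1
    return macro_f1(tp, fp, fn)


def f1_corrected_top1(probs, target, k):
    """Label correction: predict the true label if it is in the top-k, else the top-1 class."""
    num_classes = len(probs[0])
    tp, fp, fn = [0] * num_classes, [0] * num_classes, [0] * num_classes
    for row, label in zip(probs, target):
        ranked = ranked_classes(row)
        pred = label if label in ranked[:k] else ranked[0]
        if pred == label:
            tp[label] += 1
        else:
            fp[pred] += 1
            fn[label] += 1
    return macro_f1(tp, fp, fn)


random.seed(0)
num_classes, n_samples = 5, 200
probs = [[random.random() for _ in range(num_classes)] for _ in range(n_samples)]
target = [random.randrange(num_classes) for _ in range(n_samples)]

for k in range(1, num_classes + 1):
    print(k, round(f1_topk_as_positives(probs, target, k), 4), round(f1_corrected_top1(probs, target, k), 4))
```

Under this assumption the two columns agree at k=1 and diverge for larger k: the extra top-k classes add false positives in the first scheme, while the second scheme only ever turns wrong predictions into correct ones, so at k = num_classes the second column reaches 1.0 and the first stays far below it.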
Environment
- TorchMetrics 0.11.3 (installed via pip)
- Python 3.8.16
- PyTorch 1.12.0
Hi! Thanks for your contribution, great first issue!
Hi, I am interested in solving this issue. Can I work on it?
@arijitde92 sorry for the late reply; sure, you are welcome to take it :purple_heart: