
`ignore_index` for F1 doesn't behave as expected.


🐛 Bug

F1 doesn't ignore indices properly.

To Reproduce

Run the following code.

import torch
from torchmetrics import F1

f1 = F1(ignore_index=0)
# Metrics are called as metric(preds, target); every position where
# target == 0 should be ignored here.
f1(torch.tensor([1, 1, 1, 1, 2, 1, 1]), torch.tensor([0, 0, 1, 1, 2, 0, 0]))

This returns tensor(0.6000) rather than the expected tensor(1.0).

Expected behavior

Samples whose target equals the specified ignore_index should not count towards the F1 score. The code above should therefore be effectively equivalent to the following:

import torch
from torchmetrics import F1

f1 = F1()
f1(torch.tensor([1, 1, 2]), torch.tensor([1, 1, 2]))  # tensor(1.)

Environment

  • PyTorch Version (e.g., 1.0): 1.10
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): N/A
  • Python version: 3.8
  • CUDA/cuDNN version: N/A
  • GPU models and configuration: N/A
  • Any other relevant information:

yukw777 (Nov 09 '21)

Hi! Thanks for your contribution, great first issue!

github-actions[bot] (Nov 09 '21)

https://github.com/PyTorchLightning/metrics/blob/fef83a028880f2ad3e0c265b3e5bb8a184805798/torchmetrics/functional/classification/stat_scores.py#L132-L143

It seems that when reduce is not macro, the code just deletes that particular class's column, which I don't think is the right thing to do. Shouldn't we instead delete the rows (samples) whose target value is ignore_index?
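
For illustration, a minimal sketch of the row-dropping semantics suggested above (drop_ignored is a hypothetical helper, not part of torchmetrics):

import torch

def drop_ignored(preds, target, ignore_index):
    # Keep only the samples whose target is not the ignored class,
    # instead of deleting that class's column from the stat scores.
    keep = target != ignore_index
    return preds[keep], target[keep]

preds, target = drop_ignored(
    torch.tensor([1, 1, 1, 1, 2, 1, 1]),
    torch.tensor([0, 0, 1, 1, 2, 0, 0]),
    ignore_index=0,
)
# preds  -> tensor([1, 1, 2]); target -> tensor([1, 1, 2])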

yukw777 (Nov 09 '21)

I've had a similar experience using ignore_index with IoU (Jaccard index), where the IoU starts at 100.00 and tends towards 0 as training progresses.

Hommus (Jan 07 '22)

I'm not sure about the original intention of ignore_index, but to get the expected behavior for now, I make some modifications to the tensors before passing them to torchmetrics.F1:

import torch
from torchmetrics import F1

ignore_index = 0
y = torch.tensor([0, 0, 1, 1, 2, 0, 0])
y_hat = torch.tensor([1, 1, 1, 1, 2, 1, 1])

# Remap the predictions at every ignored position to ignore_index.
# Note this mutates y_hat in place; clone it first if the original
# predictions are still needed.
inactive_index = y == ignore_index
y_hat[inactive_index] = ignore_index

f1 = F1(ignore_index=0)
f1(y_hat, y)  # now returns the expected tensor(1.)
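
This works because the remapped predictions land in the ignore_index class, whose statistics the current implementation drops, so they no longer count as false positives for the other classes.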

tchayintr (Jan 07 '22)

@tchayintr would you be interested in sending a PR? @stancld may help if needed :rabbit:

Borda (Jan 19 '22)

@Borda Sure. Let me thoroughly review the code, particularly metrics/torchmetrics/functional/classification/stat_scores.py, before considering possibilities and drafting a PR.

tchayintr (Jan 20 '22)

This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all the changes.

Small recap: this issue describes that the ignore_index argument currently does not give the right result in the f1_score metric, due to how ignore_index samples are accounted for. In the refactor the code has been changed to correctly ignore those samples; see the example below using the new multiclass_f1_score function:

from torchmetrics.functional import multiclass_f1_score
import torch

preds = torch.tensor([1, 1, 1, 1, 2, 1, 1])
target = torch.tensor([0, 0, 1, 1, 2, 0, 0])

multiclass_f1_score(preds, target, num_classes=3, average="micro", ignore_index=0)  # tensor(1.)

which gives the correct result. The issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.
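
(Usage note: num_classes=3 here still counts the ignored class; ignore_index only controls which target positions are masked out, it does not change the number of classes the metric expects.)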

SkafteNicki (Aug 30 '22)