torchmetrics

`BinaryAccuracy()` sometimes gives incorrect answers due to non-deterministic sigmoiding

Open • idc9 opened this issue 1 year ago • 2 comments

🐛 Bug

`torchmetrics.classification.BinaryAccuracy` applies a sigmoid to some inputs but not others, which leads to incorrect behavior.

Details

Currently, `BinaryAccuracy()` applies a sigmoid transformation before binarizing if the inputs fall outside [0, 1]. The documentation states:

If preds is a floating point tensor with values outside [0,1] range we consider the input to be logits and will auto apply sigmoid per element.

That is, with 1(·) as the indicator function:

y_hat = 1(sigmoid(z) >= threshold)   if z lies outside [0, 1]
y_hat = 1(z >= threshold)            if z lies inside [0, 1]

I assume the [0, 1] check is made over the entire batch. In other words, if any element of the batch is outside [0, 1], the sigmoid is applied to every element.
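Here is a minimal plain-Python sketch of the rule described above. It assumes the batch-level check I guessed at and uses the `>=` comparison from the formula. `sigmoid` and `auto_binarize` are hypothetical helpers for illustration, not torchmetrics code.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def auto_binarize(preds: list[float], threshold: float = 0.5) -> list[int]:
    """Hypothetical sketch of the auto-sigmoid rule (not torchmetrics code).

    Assumes the [0, 1] check is made once for the whole batch: if any value
    lies outside [0, 1], every value is treated as a logit and sigmoided.
    """
    all_in_unit_interval = all(0.0 <= p <= 1.0 for p in preds)
    probs = preds if all_in_unit_interval else [sigmoid(p) for p in preds]
    return [int(p >= threshold) for p in probs]


# The same logit (0.49) is binarized differently depending on its batch-mates.
print(auto_binarize([0.49]))       # [0]: whole batch in [0, 1], no sigmoid
print(auto_binarize([0.49, 3.0]))  # [1, 1]: 3.0 triggers the sigmoid for everyone
```

Under a batch-level check, the prediction for a given logit depends on which other values happen to share its batch.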

This causes silent errors. A user who passes logits expects them to always be sigmoided. However, every logit in a batch can easily lie in [0, 1]. In that case the batch is not sigmoided and the threshold is applied to the raw logits. For example, with threshold = 0.5, any logit in (0, 0.5) should be predicted as class 1, since sigmoid(z) > 0.5 for z > 0. Instead, it is predicted as class 0.

To Reproduce

Here is a simple example. Suppose our network outputs logits.

```python
from torchmetrics.classification import BinaryAccuracy
from scipy.special import expit  # expit = sigmoid
import numpy as np
import torch
```

This example should yield a correct prediction:

```python
>>> probability_thresh = 0.5
>>> logits = np.array([0.49])  # network output
>>> target = np.array([1])
>>> # a logit of 0.49 gives a probability of 0.62, indicating class 1: the correct prediction
>>> expit(logits)
array([0.62010643])
>>> int(expit(logits).item() >= probability_thresh) == target.item()
True
```

`BinaryAccuracy()`, however, reports an incorrect prediction:

```python
>>> # torchmetrics reports an incorrect prediction because it does NOT sigmoid the logits
>>> ba = BinaryAccuracy(threshold=probability_thresh)
>>> ba(preds=torch.tensor(logits), target=torch.tensor(target))
tensor(0.)
```

Suggested Fix

I suggest adding an argument that indicates whether the input predictions have already been sigmoided. The inputs would then be either always sigmoided or never sigmoided.
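Here is one possible shape for this, sketched in plain Python. It assumes the caller declares the input type explicitly through a hypothetical `from_logits` flag, so the sigmoid decision no longer depends on the values in the batch. `binarize`, `accuracy` and `from_logits` are illustrative names, not existing torchmetrics API.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binarize(preds: list[float], threshold: float = 0.5, *, from_logits: bool) -> list[int]:
    """Binarize with an explicit input type instead of guessing from the values.

    `from_logits` is a hypothetical argument illustrating the suggested fix;
    it is not an existing torchmetrics parameter.
    """
    if from_logits:
        probs = [sigmoid(z) for z in preds]  # always sigmoided
    else:
        if not all(0.0 <= p <= 1.0 for p in preds):
            raise ValueError("from_logits=False, but some preds lie outside [0, 1]")
        probs = list(preds)  # never sigmoided
    return [int(p >= threshold) for p in probs]


def accuracy(preds: list[float], target: list[int], threshold: float = 0.5, *, from_logits: bool) -> float:
    # Fraction of binarized predictions that match the integer targets.
    labels = binarize(preds, threshold, from_logits=from_logits)
    return sum(int(y_hat == y) for y_hat, y in zip(labels, target)) / len(target)


# The reproduction's logit of 0.49 is now always sigmoided, so the prediction is correct.
print(accuracy([0.49], [1], from_logits=True))  # 1.0
```

Until such an argument exists, a caller holding logits should be able to get the same consistency from the documented rule. Apply `torch.sigmoid` to the logits before passing them to `BinaryAccuracy`. The sigmoid outputs already lie in [0, 1], so no second sigmoid is applied.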

idc9, Mar 09 '23 17:03

Hi! thanks for your contribution!, great first issue!

github-actions[bot], Mar 09 '23 17:03

It looks like the same thing happens in `MultilabelAccuracy` (https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html).

idc9, Mar 09 '23 17:03