torchmetrics

Fix sigmoid overflow for large logits causing incorrect AUROC results

Open Copilot opened this issue 5 months ago • 0 comments

Description

Fixes an issue where binary_auroc and other classification metrics return incorrect results when logits are very large (above roughly 16.6 in float32 and 36.7 in float64). For such values the sigmoid rounds to exactly 1.0. Every prediction then becomes identical, and the ranking information needed for the AUROC calculation is lost.

Problem

When all logits lie in a large range (e.g., 97-100), applying sigmoid naively saturates every value to 1.0:

import torch
from torchmetrics.functional.classification import binary_auroc

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 
                      99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 
                      99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

# Before fix: returns 0.5 (random guessing)
# After fix: returns 0.9286 (correct)
binary_auroc(preds, labels)

The issue occurs because, in float32, sigmoid(x) evaluates to exactly 1.0 for x above roughly 16.6. All predictions become indistinguishable, which destroys the ranking information that AUROC depends on.
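The float32 and float64 thresholds can be derived from the mantissa width. The sketch below is a pure-Python illustration, not torchmetrics code. It assumes the naive formula 1 / (1 + exp(-x)), rounded to float32 after each step; the rounding is emulated with the standard-library struct module. A sigmoid implementation that evaluates differently may saturate at a slightly different point. The helper names to_float32, saturation_threshold and naive_sigmoid_float32 are hypothetical.

```python
import math
import struct


def to_float32(x: float) -> float:
    """Round a Python float (IEEE float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def saturation_threshold(mantissa_bits: int) -> float:
    """Logit above which 1 + exp(-x) rounds to exactly 1.

    The gap between 1 and the next representable number is 2**-mantissa_bits.
    Round-to-nearest discards any addend up to half that gap, so saturation
    starts where exp(-x) = 2**-(mantissa_bits + 1), i.e. x = (mantissa_bits + 1) * ln 2.
    """
    return (mantissa_bits + 1) * math.log(2)


def naive_sigmoid_float32(x: float) -> float:
    """Emulate 1 / (1 + exp(-x)) with float32 rounding after every step."""
    e = to_float32(math.exp(-x))
    denom = to_float32(1.0 + e)
    return to_float32(1.0 / denom)


print(f"float32: {saturation_threshold(23):.2f}")  # float32 stores 23 mantissa bits -> 16.64
print(f"float64: {saturation_threshold(52):.2f}")  # float64 stores 52 mantissa bits -> 36.74
print(naive_sigmoid_float32(16.0) < 1.0)           # True: still distinguishable from 1
print(naive_sigmoid_float32(17.0) == 1.0)          # True: already saturated

issue_logits = [98.0950, 98.4612, 98.1145, 98.1506, 97.6037]  # first five logits from the issue
print({naive_sigmoid_float32(x) for x in issue_logits})        # {1.0}: all collapse to one value
```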

Solution

Modified normalize_logits_if_needed in src/torchmetrics/utilities/compute.py to apply a numerically stable sigmoid when needed:

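  • Conditional stabilization: Only applies when min(logits) > 15, i.e., when every logit is at or near the point where sigmoid saturates
  • Preserves ranking: Subtracts the max value before sigmoid; since both the shift and the sigmoid are strictly increasing, the relative ordering is maintained
  • Avoids artificial ties: Does not apply stabilization to mixed-range logits (e.g., -5 to 100). There, shifting by the max would push the most negative logits so low that their sigmoid underflows to 0, creating spurious ties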
  • Backward compatible: Normal-range logits use standard sigmoid, maintaining existing behavior
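The sketch below illustrates the behaviour described above in NumPy. rank_preserving_sigmoid is a hypothetical stand-in, not the actual normalize_logits_if_needed implementation, and it operates on float32 NumPy arrays rather than torch tensors. It shows two things: shifting by the maximum keeps all 15 logits from the issue distinct, and applying the same shift to mixed-range input would create ties.

```python
import numpy as np


def naive_sigmoid(x: np.ndarray) -> np.ndarray:
    """Plain 1 / (1 + exp(-x)); a float32 input stays float32."""
    with np.errstate(over="ignore"):  # exp(-x) may overflow to inf for very negative x
        return 1.0 / (1.0 + np.exp(-x))


def rank_preserving_sigmoid(logits: np.ndarray, min_threshold: float = 15.0) -> np.ndarray:
    """Hypothetical helper mirroring the behaviour described in this PR.

    If every logit exceeds min_threshold, shift by the maximum first so the
    sigmoid no longer saturates; otherwise apply the plain sigmoid.
    """
    if logits.min() > min_threshold:
        logits = logits - logits.max()  # values now lie in [min - max, 0]
    return naive_sigmoid(logits)


issue_logits = np.array(
    [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
     99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768],
    dtype=np.float32,
)
print(np.unique(naive_sigmoid(issue_logits)).size)            # 1: every value saturates to 1.0
print(np.unique(rank_preserving_sigmoid(issue_logits)).size)  # 15: ranking survives

# Mixed range: shifting by the max (100) would map -5 and -4 to -105 and -104,
# whose sigmoid becomes 0.0 in float32, so the two values would tie.
mixed = np.array([-5.0, -4.0, 100.0], dtype=np.float32)
print(naive_sigmoid(mixed - mixed.max()))  # first two entries are both 0.0
print(rank_preserving_sigmoid(mixed))      # min < 15, no shift: first two entries stay distinct
```

The shifted outputs are no longer the sigmoid probabilities of the original logits. Only their ordering is preserved, which is what ranking metrics such as AUROC rely on.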

Changes

  • Updated normalize_logits_if_needed() to check both min and max values before stabilization
  • Added a regression test covering three cases:
    • Original issue case (logits 97-100)
    • Very large logits (200+)
    • Mixed range logits (-5 to 100)
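The three regression cases can be sketched as pytest-style checks. Everything here is illustrative:

- auroc_trapezoid is a NumPy port of the trapezoid algorithm from the miniature binary_auroc in the original issue.
- rank_preserving_sigmoid is a hypothetical stand-in for the stabilized sigmoid.
- The "very large" and "mixed range" inputs are made-up data, not the test data added in this PR.

```python
import numpy as np


def naive_sigmoid(x: np.ndarray) -> np.ndarray:
    with np.errstate(over="ignore"):  # exp(-x) may overflow to inf for very negative x
        return 1.0 / (1.0 + np.exp(-x))


def rank_preserving_sigmoid(logits: np.ndarray, min_threshold: float = 15.0) -> np.ndarray:
    """Hypothetical stand-in for the stabilized sigmoid described in this PR."""
    if logits.min() > min_threshold:
        logits = logits - logits.max()
    return naive_sigmoid(logits)


def auroc_trapezoid(probs: np.ndarray, target: np.ndarray) -> float:
    """NumPy port of the miniature binary_auroc from the original issue."""
    order = np.argsort(probs)[::-1]  # sort scores in descending order
    probs, target = probs[order], target[order]
    # Indices where the score changes, plus the last index to close the curve.
    threshold_idxs = np.append(np.nonzero(np.diff(probs))[0], target.size - 1)
    tps = np.cumsum(target)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    # Prepend a zero so the ROC curve starts at (0, 0).
    tpr = np.concatenate([[0], tps]) / tps[-1]
    fpr = np.concatenate([[0], fps]) / fps[-1]
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))


ISSUE_LOGITS = np.array(
    [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
     99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768],
    dtype=np.float32,
)
ISSUE_LABELS = np.array([0] + [1] * 14)


def test_issue_case():
    # The bug: saturated probabilities all tie, so the curve is the diagonal.
    assert auroc_trapezoid(naive_sigmoid(ISSUE_LOGITS), ISSUE_LABELS) == 0.5
    assert abs(auroc_trapezoid(rank_preserving_sigmoid(ISSUE_LOGITS), ISSUE_LABELS) - 13 / 14) < 1e-9


def test_very_large_logits():
    logits = ISSUE_LOGITS + 110  # same ordering, moved into the 200+ range
    assert abs(auroc_trapezoid(rank_preserving_sigmoid(logits), ISSUE_LABELS) - 13 / 14) < 1e-9


def test_mixed_range_logits():
    logits = np.array([-5.0, 2.0, 0.0, 100.0], dtype=np.float32)
    labels = np.array([0, 0, 1, 1])
    # min < 15, so no shift; 3 of the 4 (positive, negative) pairs are ordered correctly.
    assert abs(auroc_trapezoid(rank_preserving_sigmoid(logits), labels) - 0.75) < 1e-9
```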

Testing

All existing tests pass:

  • ✅ 92 binary AUROC tests passed
  • ✅ 30 precision-recall curve tests passed
  • ✅ 30 stat_scores tests passed
  • ✅ New regression test with 3 cases added

Closes #XXXX

Original prompt

This section details the original issue you should resolve

<issue_title>torchmetrics.functional.classification.binary_auroc gives wrong results when logits are large</issue_title>
<issue_description>🐛 Bug

torchmetrics.functional.classification.binary_auroc always gives 0.5 when all logits are large. This seems to be caused by a floating point precision error with sigmoid.

To Reproduce

Code sample
import torch
import torchmetrics.functional.classification

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

torchmetrics.functional.classification.binary_auroc(preds, labels)

Output:

tensor(0.5000)

Expected behavior

AUROC of the above example should be 0.9286, as computed by sklearn.

import sklearn.metrics

sklearn.metrics.roc_auc_score(labels, preds)

Output:

0.9285714285714286
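The expected value can be checked by hand. AUROC equals the fraction of (positive, negative) pairs in which the positive is scored higher, with ties counting as one half. The minimal pure-Python check below uses that pairwise definition; it is not how sklearn computes the score. There is one negative (98.0950) and 14 positives, and only 97.6037 scores lower than the negative, giving 13/14. If every score is tied, as after the saturated sigmoid, each pair contributes one half and the result is exactly 0.5, matching the incorrect output.

```python
preds = [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
         99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768]
labels = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


def pairwise_auroc(scores, targets):
    """AUROC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for s, t in zip(scores, targets) if t == 1]
    neg = [s for s, t in zip(scores, targets) if t == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


print(pairwise_auroc(preds, labels))               # 0.9285714285714286 (= 13/14)
print(pairwise_auroc([1.0] * len(preds), labels))  # 0.5 when every score is tied
```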

Environment

  • Windows 11 24H2
  • Python version 3.10.11
  • TorchMetrics version 1.4.3
  • PyTorch version 2.4.1+cu124

Additional context

This appears to be a floating-point precision problem with the sigmoid call at line 185 of function _binary_precision_recall_curve_format in file torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py.

I extracted all the necessary functions and wrote a miniature binary_auroc function that uses exactly the same algorithm. It reproduces the example above; I did not test other inputs:

import torch


def binary_auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    preds = preds.sigmoid()
    print(preds)

    desc_score_indices = torch.argsort(preds, descending=True)
    preds = preds[desc_score_indices]
    target = target[desc_score_indices]
    # print(preds, target)

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.nonzero(preds[1:] - preds[:-1], as_tuple=True)[0]
    # print(distinct_value_indices)
    threshold_idxs = torch.nn.functional.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
    # print(threshold_idxs)
    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    # print(tps, fps)

    # Add an extra threshold position to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    # print(fpr, tpr)
    return torch.trapezoid(tpr, fpr, dim=-1)

Output:

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

tensor(0.5000)

preds = preds.sigmoid() converts all logits to 1, as if they were all the same, which is not the case. A positive logit must be less than about 36.74 in double precision, or 16.64 in float32, to avoid being converted to exactly 1.

Suggested fix

It's probably a good idea to scale the raw logits before sigmoid, something like below:

preds = preds / preds.abs().max()  # scale so the largest |logit| becomes 1 (out of place, so the caller's tensor is not modified)
preds = preds.sigmoid()

All functions that apply sigmoid to raw logits will need such a fix.</issue_description>
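Below is a quick NumPy check of the suggested scaling on the issue's logits in float32. It is an illustration, not a proposed patch. Dividing by the largest absolute value is a positive rescaling, so the order of the logits is unchanged. The scaled values (about 0.976 to 1.0) are far below the saturation point, so all 15 sigmoid outputs stay distinct.

```python
import numpy as np

logits = np.array(
    [98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425, 99.2644, 99.5014,
     99.7280, 99.6595, 99.6931, 99.4667, 99.9623, 99.8949, 99.8768],
    dtype=np.float32,
)

# Suggested fix: rescale so the largest |logit| becomes 1, then apply sigmoid.
scaled = logits / np.abs(logits).max()
probs = 1.0 / (1.0 + np.exp(-scaled))

print(scaled.min(), scaled.max())                             # roughly 0.976 and 1.0
print(np.unique(probs).size)                                  # 15 distinct values
print(np.array_equal(np.argsort(probs), np.argsort(logits)))  # True: ranking preserved
```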

Comments on the Issue (you are @copilot in this section)

@Borda, replying to the suggested fix above (scale the raw logits by their maximum absolute value before applying sigmoid):

That sounds reasonable to me... @SkafteNicki your thoughts? 🤔 </comment_new>

Fixes Lightning-AI/torchmetrics#2819



📚 Documentation preview 📚: https://torchmetrics--3283.org.readthedocs.build/en/3283/

Copilot, Oct 04 '25 20:10