Fix sigmoid overflow for large logits causing incorrect AUROC results
Description
Fixes an issue where binary_auroc and other classification metrics return incorrect results when logits are very large (above roughly 16.6 in float32 or 36.7 in float64). The sigmoid function saturates to exactly 1.0 for all such values, losing the ranking information needed for AUROC calculation.
Problem
When all logits are in a large range (e.g., 97-100), naive sigmoid application saturates:

```py
import torch
from torchmetrics.functional.classification import binary_auroc

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425,
                      99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667,
                      99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

# Before fix: returns 0.5 (random guessing)
# After fix: returns 0.9286 (correct)
binary_auroc(preds, labels)
```
The issue occurs because sigmoid(x) for x above roughly 16.6 evaluates to exactly 1.0 in float32, making all predictions indistinguishable and destroying the ranking information that AUROC depends on.
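A minimal demonstration of the cutoff; the values 15 and 20 straddle the ~16.6 threshold, and tensors are float32 by default:

```py
import torch

x = torch.tensor([15.0, 20.0, 100.0])  # float32 by default
print(torch.sigmoid(x) == 1.0)
# tensor([False,  True,  True]) -- 20.0 and 100.0 both saturate to exactly
# 1.0, so any ordering between them is already gone before AUROC runs
```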
Solution
Modified normalize_logits_if_needed in src/torchmetrics/utilities/compute.py to apply a numerically stable sigmoid when needed (a sketch follows the list):

- Conditional stabilization: only applies when min(logits) > 15, indicating that every value would saturate
- Preserves ranking: subtracts the max value before sigmoid; since sigmoid is monotonic, the relative ordering is maintained
- Avoids artificial ties: does not apply stabilization to mixed-range logits (e.g., -5 to 100), where shifting would create spurious ties
- Backward compatible: normal-range logits use the standard sigmoid, so existing behavior is unchanged
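A simplified sketch of the sigmoid path described above; the real normalize_logits_if_needed has a broader signature (it also handles softmax and probability inputs), and the helper name here is illustrative:

```py
import torch


def _stable_sigmoid_normalize(preds: torch.Tensor) -> torch.Tensor:
    """Illustrative helper, not the exact torchmetrics code."""
    if preds.min() > 15:
        # Every value would saturate sigmoid to exactly 1.0 in float32.
        # Subtracting the max preserves the ordering (sigmoid is strictly
        # monotonic) and moves the values into a range where float32
        # sigmoid still distinguishes them.
        preds = preds - preds.max()
    # Normal-range and mixed-range logits fall through to the plain
    # sigmoid, so existing behavior is unchanged.
    return preds.sigmoid()
```

For the repro above, preds - preds.max() maps the 97.6-100.0 logits to roughly -2.36-0.0, where sigmoid yields fifteen distinct probabilities and the ranking survives.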
Changes
- Updated normalize_logits_if_needed() to check both min and max values before stabilization
- Added a comprehensive regression test (an illustrative sketch follows this list) covering:
  - the original issue case (logits 97-100)
  - very large logits (200+)
  - mixed-range logits (-5 to 100)
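An illustrative shape for that regression test; the parametrization, shift values, and tolerance are assumptions for this sketch, not the exact test added by the PR:

```py
import pytest
import torch
from torchmetrics.functional.classification import binary_auroc

PREDS = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425,
                      99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667,
                      99.9623, 99.8949, 99.8768])
TARGET = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


@pytest.mark.parametrize("shift", [0.0, 105.0])  # logits ~97-100 and 200+
def test_binary_auroc_large_logits(shift: float) -> None:
    # AUROC is rank-based, so shifting all logits must not change the result.
    result = binary_auroc(PREDS + shift, TARGET)
    assert torch.isclose(result, torch.tensor(0.9286), atol=1e-3)


def test_binary_auroc_mixed_range_logits() -> None:
    # Mixed-range logits must not trigger the shift (it could create ties).
    # The lone negative has the lowest score, so AUROC should be exactly 1.
    preds = torch.tensor([-5.0, 100.0, 50.0])
    target = torch.tensor([0, 1, 1])
    assert binary_auroc(preds, target) == 1.0
```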
Testing
All existing tests pass:
- ✅ 92 binary AUROC tests passed
- ✅ 30 precision-recall curve tests passed
- ✅ 30 stat_scores tests passed
- ✅ New regression test with 3 cases added
Closes #XXXX
Original prompt
This section details the original issue you should resolve.
<issue_title>torchmetrics.functional.classification.binary_auroc gives wrong results when logits are large</issue_title>
<issue_description>
## 🐛 Bug

torchmetrics.functional.classification.binary_auroc always gives 0.5 when all logits are large. This seems to be caused by a floating point precision error with sigmoid.

## To Reproduce

Code sample:

```py
import torch
import torchmetrics.functional.classification

preds = torch.tensor([98.0950, 98.4612, 98.1145, 98.1506, 97.6037, 98.9425,
                      99.2644, 99.5014, 99.7280, 99.6595, 99.6931, 99.4667,
                      99.9623, 99.8949, 99.8768])
labels = torch.tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
torchmetrics.functional.classification.binary_auroc(preds, labels)
```

Output:

```
tensor(0.5000)
```

## Expected behavior

AUROC of the above example should be 0.9286, as computed by sklearn:

```py
import sklearn.metrics

sklearn.metrics.roc_auc_score(labels, preds)
```

Output:

```
0.9285714285714286
```

## Environment
- Windows 11 24H2
- Python version 3.10.11
- TorchMetrics version 1.4.3
- PyTorch version 2.4.1+cu124
## Additional context

This appears to be a problem of floating point precision with sigmoid at line 185 in function _binary_precision_recall_curve_format in file torchmetrics/src/torchmetrics/functional/classification/precision_recall_curve.py. I extracted all the necessary functions and made a miniature binary_auroc function that uses exactly the same algorithm (works for the above example, did not test for other examples):

```py
def binary_auroc(
    preds: torch.Tensor,
    target: torch.Tensor,
) -> torch.Tensor:
    preds = preds.sigmoid()
    print(preds)
    desc_score_indices = torch.argsort(preds, descending=True)
    preds = preds[desc_score_indices]
    target = target[desc_score_indices]
    # print(preds, target)

    # pred typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = torch.nonzero(preds[1:] - preds[:-1], as_tuple=True)[0]
    # print(distinct_value_indices)
    threshold_idxs = torch.nn.functional.pad(distinct_value_indices, [0, 1], value=target.size(0) - 1)
    # print(threshold_idxs)

    tps = torch.cumsum(target, dim=0)[threshold_idxs]
    fps = 1 + threshold_idxs - tps
    # print(tps, fps)

    # Add an extra threshold position to make sure that the curve starts at (0, 0)
    tps = torch.cat([torch.zeros(1, dtype=tps.dtype, device=tps.device), tps])
    fps = torch.cat([torch.zeros(1, dtype=fps.dtype, device=fps.device), fps])

    tpr = tps / tps[-1]
    fpr = fps / fps[-1]
    # print(fpr, tpr)
    return torch.trapezoid(tpr, fpr, dim=-1)
```

Output:

```
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor(0.5000)
```
`preds = preds.sigmoid()` converts all logits to 1 as if all the logits were the same, which is not the case. The maximum magnitude of a logit must be less than 36.74 for double or 16.64 for float32 to avoid being converted to exactly 1.
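Those cutoffs follow from machine epsilon: once exp(-x) drops below eps/2, the sum 1 + exp(-x) rounds to exactly 1.0. A quick check of the arithmetic, assuming sigmoid is evaluated as 1/(1 + exp(-x)):

```py
import math
import torch

for dtype in (torch.float32, torch.float64):
    eps = torch.finfo(dtype).eps
    # smallest x beyond which 1 + exp(-x) rounds to 1.0, i.e. sigmoid(x) == 1.0
    print(dtype, round(-math.log(eps / 2), 2))
# torch.float32 16.64
# torch.float64 36.74
```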
## Suggested fix

It's probably a good idea to scale the raw logits before sigmoid, something like below:

```py
preds /= torch.max(torch.abs(preds))  # scales max element to 1
preds = preds.sigmoid()
```

All functions that apply sigmoid to raw logits will need such a fix.
</issue_description>
Comments on the Issue (you are @copilot in this section)
@Borda:

> ### Suggested fix
> It's probably a good idea to scale the raw logits before sigmoid, something like below:
> ```py
> preds /= torch.max(torch.abs(preds))  # scales max element to 1
> preds = preds.sigmoid()
> ```
> All functions that apply sigmoid to raw logits will need such a fix.

That sounds reasonable to me... @SkafteNicki your thoughts? 🤔
📚 Documentation preview 📚: https://torchmetrics--3283.org.readthedocs.build/en/3283/