Constant-memory implementation of common threshold-varying metrics
Thanks to the developers for their great work. I (and my work) use this package heavily and will continue to do so :).
🚀 Feature
Related to #128, it would be great to also have constant-memory implementations of the ROC curve and the AUC metric. Given that this has already been implemented for precision-recall, the work is 90% done; it just needs a minor extension.
Motivation
The current implementation of AUROC stores every prediction and target until `compute` is called, so its memory footprint grows with the size of the dataset.
Pitch
I am happy to submit a PR for this (together with help from @norrishd, @ronrest, @BlakeJC94, & @ryanmseer), and would propose to do it by:
- Using logic similar to what is currently in `BinnedPrecisionRecallCurve`, write a more general base class that implements a "confusion matrix curve", where each of TP, TN, FP, and FN is calculated for a set of defined threshold values. An argument to the parent class could allow child classes to specify a subset of these 4 values to prevent unnecessary calculation. The base class could be called something like `ConfusionMatrixCurve`, `StatScoresCurve`, or `BinnedStatScores` (preferences welcome); see the sketch after this list.
- Simplify `BinnedPrecisionRecallCurve` by inheriting from this base class.
- Implement a constant-memory version of the ROC called `BinnedROC` that inherits from the base class.
- Similarly, we could also add constant-memory implementations of things like `BinnedAUROC` and `BinnedSpecificityAtFixedSensitivity`.
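To make the first point concrete, here is a rough sketch of what such a base class could look like (the name `BinnedStatScores`, the `stats` argument, and the plain-class structure are placeholders for discussion, not a final design):

```python
import torch

class BinnedStatScores:
    """Sketch: accumulate confusion-matrix counts at a fixed set of
    thresholds, so state is O(num_thresholds) regardless of dataset size."""

    def __init__(self, thresholds: torch.Tensor, stats=("tp", "fp", "tn", "fn")):
        self.thresholds = thresholds
        # Only allocate counters for the stats a child class asks for.
        self.state = {s: torch.zeros(len(thresholds), dtype=torch.long) for s in stats}

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        pred_pos = preds.unsqueeze(0) >= self.thresholds.unsqueeze(1)  # (T, N)
        target_pos = target.unsqueeze(0).bool()                        # (1, N)
        masks = {
            "tp": lambda: pred_pos & target_pos,
            "fp": lambda: pred_pos & ~target_pos,
            "tn": lambda: ~pred_pos & ~target_pos,
            "fn": lambda: ~pred_pos & target_pos,
        }
        # Only the requested stats are computed and accumulated.
        for s in self.state:
            self.state[s] += masks[s]().sum(dim=1)
```

A child like `BinnedPrecisionRecallCurve` would then request only `("tp", "fp", "fn")`, while `BinnedROC` needs all four counts to derive TPR and FPR in its `compute` step.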
I've implemented similar functionality in a different metrics package here, but for this implementation I would of course follow the patterns and conventions of torchmetrics as best as I could.
Alternatives
Hi! Thanks for your contribution, great first issue!
I'm currently struggling with AUROC's memory footprint, so this would honestly be amazing!
Hi @timesler,
This sounds like a great addition to torchmetrics. We are aware that some of our implementations suffer from a potentially huge memory footprint, so it would be great to have metrics that solve this.
Please feel free to send a PR :]
One small question: would it make sense to combine the binned and non-binned metrics into a single metric, with a parameter to switch between them, e.g.

```python
class AUROC(Metric):
    def __init__(self, reduce_memory: bool = False):
        super().__init__()
        self.reduce_memory = reduce_memory
        ...

    def update(self, preds, target):
        if self.reduce_memory:
            binned_update(preds, target)   # constant-memory path
        else:
            current_update(preds, target)  # existing exact path
```

instead of having multiple versions of the same metric?
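For concreteness, a minimal sketch of how a single metric could dispatch between the two paths (the class name, the binned compute, and the `num_thresholds` default are all illustrative assumptions, and the exact path simply defers to the functional API):

```python
import torch
from torchmetrics import Metric
from torchmetrics.functional import binary_auroc

class MemoryAwareAUROC(Metric):
    """Illustrative sketch only; not the actual torchmetrics implementation."""

    def __init__(self, reduce_memory: bool = False, num_thresholds: int = 100):
        super().__init__()
        self.reduce_memory = reduce_memory
        if reduce_memory:
            # Constant-size state: one counter per threshold.
            self.thresholds = torch.linspace(0, 1, num_thresholds)
            self.add_state("tp", default=torch.zeros(num_thresholds), dist_reduce_fx="sum")
            self.add_state("fp", default=torch.zeros(num_thresholds), dist_reduce_fx="sum")
            self.add_state("pos", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("neg", default=torch.tensor(0.0), dist_reduce_fx="sum")
        else:
            # Exact state: grows linearly with the number of samples seen.
            self.add_state("preds", default=[], dist_reduce_fx="cat")
            self.add_state("target", default=[], dist_reduce_fx="cat")

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        if self.reduce_memory:
            pred_pos = preds.unsqueeze(0) >= self.thresholds.unsqueeze(1)  # (T, N)
            target_pos = target.unsqueeze(0).bool()                        # (1, N)
            self.tp += (pred_pos & target_pos).sum(dim=1)
            self.fp += (pred_pos & ~target_pos).sum(dim=1)
            self.pos += target.sum()
            self.neg += target.numel() - target.sum()
        else:
            self.preds.append(preds)
            self.target.append(target)

    def compute(self) -> torch.Tensor:
        if self.reduce_memory:
            tpr, fpr = self.tp / self.pos, self.fp / self.neg
            order = torch.argsort(fpr)  # trapezoidal AUC over the binned curve
            return torch.trapz(tpr[order], fpr[order])
        return binary_auroc(torch.cat(self.preds), torch.cat(self.target))
```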
I'm beginning some exploratory work on this feature now in consultation with @timesler.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
We're still working on this.
This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes.
Small recap: this issue asks for constant-memory implementations of common metrics like `roc` etc. After the refactor, metrics such as `roc`, `auroc`, `precision_recall`, and `average_precision` all support constant-memory implementations by providing the `thresholds` argument. If `thresholds=None`, the standard approach is used (which is accurate but not memory-constant); if `thresholds=10` (an int) or `thresholds=[0.1, 0.3, 0.5, 0.7]` (a list of floats), a binned version is used that is less accurate but memory-constant. Example below for `roc`:
```python
from torchmetrics.functional import binary_roc
import torch
preds = torch.rand(20)
target = torch.randint(0, 2, (20,))
binary_roc(preds, target, thresholds=None) # accurate and memory intensive
# (tensor([0.0000, 0.1667, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.6667, 0.6667,
# 0.6667, 0.6667, 0.8333, 0.8333, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
# 1.0000, 1.0000, 1.0000]),
# tensor([0.0000, 0.0000, 0.0000, 0.0714, 0.1429, 0.2143, 0.2143, 0.2143, 0.2857,
# 0.3571, 0.4286, 0.4286, 0.5000, 0.5714, 0.5714, 0.6429, 0.7143, 0.7857,
# 0.8571, 0.9286, 1.0000]),
# tensor([1.0000, 0.9995, 0.8895, 0.8621, 0.8426, 0.8204, 0.8044, 0.7560, 0.7169,
# 0.7023, 0.6685, 0.6194, 0.5599, 0.5071, 0.4728, 0.4574, 0.4332, 0.2989,
# 0.2535, 0.2446, 0.2025]))
binary_roc(preds, target, thresholds=10)  # less accurate but memory constant
# (tensor([0.0000, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
# 1.0000]),
# tensor([0.0000, 0.0000, 0.2143, 0.4286, 0.5000, 0.6429, 0.7143, 0.9286, 1.0000,
# 1.0000]),
# tensor([1.0000, 0.8889, 0.7778, 0.6667, 0.5556, 0.4444, 0.3333, 0.2222, 0.1111,
# 0.0000]))
```
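The same `thresholds` argument should also be available on the class-based versions of these metrics; a small usage sketch, assuming torchmetrics >= 0.10:

```python
import torch
from torchmetrics.classification import BinaryAUROC

preds = torch.rand(20)
target = torch.randint(0, 2, (20,))

exact = BinaryAUROC(thresholds=None)  # exact; state grows with the data
binned = BinaryAUROC(thresholds=100)  # approximate; constant memory
exact.update(preds, target)
binned.update(preds, target)
exact.compute(), binned.compute()
```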
This issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.