Constant-memory implementation of common threshold-varying metrics

timesler opened this issue 2 years ago • 6 comments

Thanks to the developers for their great work. I use this package heavily (personally and at work) and will continue to do so :).

🚀 Feature

Related to #128, it would be great to also have constant-memory implementations of the ROC curve and the AUC metric. Given that this has already been implemented for precision-recall, the work is 90% done; it just needs a minor extension.

Motivation

The current implementation of AUROC consumes a lot of memory: it stores every prediction and target across update calls, so memory grows linearly with dataset size.

Pitch

I am happy to submit a PR for this (together with help from @norrishd, @ronrest, @BlakeJC94, & @ryanmseer), and would propose to do it by:

  1. Using logic similar to what is in BinnedPrecisionRecallCurve right now, write a more general base class that implements a "confusion matrix curve", where each of TP, TN, FP, and FN is calculated for a set of defined threshold values (a rough sketch follows this list). An argument to the parent class could allow child classes to request only a subset of these four values, to avoid unnecessary computation. The base class could be called something like ConfusionMatrixCurve, StatScoresCurve, or BinnedStatScores (preferences welcome).
  2. Simplify BinnedPrecisionRecallCurve by inheriting from this base class.
  3. Implement a constant memory version of the ROC called BinnedROC that inherits from the base class.
  4. Similarly, we could also add constant memory implementations of things like BinnedAUROC, and BinnedSpecificityAtFixedSensitivity.
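
For concreteness, here is a rough sketch of what the base class's update step could look like for the binary case. The class and method names are illustrative placeholders, not an existing torchmetrics API; a real implementation would subclass Metric and register its state with add_state so it synchronizes across processes:

import torch

class BinnedStatScores:
    """Hypothetical base class: accumulates TP/FP/TN/FN per threshold."""

    def __init__(self, thresholds: torch.Tensor):
        self.thresholds = thresholds  # shape (T,)
        self.tp = torch.zeros(len(thresholds), dtype=torch.long)
        self.fp = torch.zeros(len(thresholds), dtype=torch.long)
        self.tn = torch.zeros(len(thresholds), dtype=torch.long)
        self.fn = torch.zeros(len(thresholds), dtype=torch.long)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        # (T, N) boolean grid: does each prediction clear each threshold?
        pred_pos = preds.unsqueeze(0) >= self.thresholds.unsqueeze(1)
        target_pos = target.bool().unsqueeze(0)
        # State stays O(T) no matter how many samples have been seen
        self.tp += (pred_pos & target_pos).sum(dim=1)
        self.fp += (pred_pos & ~target_pos).sum(dim=1)
        self.tn += (~pred_pos & ~target_pos).sum(dim=1)
        self.fn += (~pred_pos & target_pos).sum(dim=1)

A BinnedROC child would then only need a compute step, e.g. tpr = tp / (tp + fn) and fpr = fp / (fp + tn), and BinnedAUROC could integrate that curve with the trapezoidal rule.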

I've implemented similar functionality in a different metrics package here, but for this implementation I would of course follow the patterns and conventions of torchmetrics as closely as I can.

Alternatives

timesler avatar Nov 16 '21 01:11 timesler

Hi! Thanks for your contribution, great first issue!

github-actions[bot] avatar Nov 16 '21 01:11 github-actions[bot]

I'm currently struggling with AUROC's memory footprint, so this would honestly be amazing!

miccio-dk avatar Nov 19 '21 15:11 miccio-dk

Hi @timesler,

This sounds like a great addition to torchmetrics. We are aware that some of our implementations have a potentially huge memory footprint, so it would be great to have metrics that solve this.

Please feel free to send a PR :]

One small question: would it make sense to combine binned and non-binned metrics into a single metric, with a parameter to switch between them, e.g.

from torchmetrics import Metric

class AUROC(Metric):
    def __init__(self, reduce_memory: bool = False):
        super().__init__()
        self.reduce_memory = reduce_memory
        ...

    def update(self, preds, target):
        if self.reduce_memory:
            self._binned_update(preds, target)   # constant-memory, binned path
        else:
            self._current_update(preds, target)  # exact path, stores all inputs

instead of having multiple versions of the same metric?

SkafteNicki avatar Nov 24 '21 12:11 SkafteNicki

I'm beginning some exploratory work on this feature now, in consultation with @timesler.

ryandaryl avatar Jan 05 '22 07:01 ryandaryl

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale[bot] avatar Mar 19 '22 16:03 stale[bot]

We're still working on this.

ryandaryl avatar Mar 20 '22 22:03 ryandaryl

This issue will be fixed by the classification refactor: see issue https://github.com/Lightning-AI/metrics/issues/1001 and PR https://github.com/Lightning-AI/metrics/pull/1195 for all changes.

Small recap: this issue asks for constant-memory implementations of common metrics like ROC etc. After the refactor, metrics such as roc, auroc, precision_recall, and average_precision all support constant-memory implementations through the thresholds argument. If thresholds=None, the standard approach is used (accurate but not memory constant); if thresholds=10 (int) or thresholds=[0.1, 0.3, 0.5, 0.7] (list of floats), a binned version is used instead, which is less accurate but memory constant. Example below for roc:

from torchmetrics.functional import binary_roc
import torch

preds = torch.rand(20)
target = torch.randint(0, 2, (20,))

binary_roc(preds, target, thresholds=None)  # accurate and memory intensive 
# (tensor([0.0000, 0.1667, 0.3333, 0.3333, 0.3333, 0.3333, 0.5000, 0.6667, 0.6667,
#         0.6667, 0.6667, 0.8333, 0.8333, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
#         1.0000, 1.0000, 1.0000]),
# tensor([0.0000, 0.0000, 0.0000, 0.0714, 0.1429, 0.2143, 0.2143, 0.2143, 0.2857,
#         0.3571, 0.4286, 0.4286, 0.5000, 0.5714, 0.5714, 0.6429, 0.7143, 0.7857,
#         0.8571, 0.9286, 1.0000]),
# tensor([1.0000, 0.9995, 0.8895, 0.8621, 0.8426, 0.8204, 0.8044, 0.7560, 0.7169,
#         0.7023, 0.6685, 0.6194, 0.5599, 0.5071, 0.4728, 0.4574, 0.4332, 0.2989,
#         0.2535, 0.2446, 0.2025]))
binary_roc(preds, target, thresholds=10)   # less accurate but memory constant
# (tensor([0.0000, 0.3333, 0.5000, 0.6667, 0.8333, 1.0000, 1.0000, 1.0000, 1.0000,
#         1.0000]),
#  tensor([0.0000, 0.0000, 0.2143, 0.4286, 0.5000, 0.6429, 0.7143, 0.9286, 1.0000,
#         1.0000]),
# tensor([1.0000, 0.8889, 0.7778, 0.6667, 0.5556, 0.4444, 0.3333, 0.2222, 0.1111,
#         0.0000]))
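
The same thresholds argument works for the other refactored metrics as well; for example, a minimal sketch for auroc (assuming binary_auroc mirrors binary_roc above, as described in the recap):

from torchmetrics.functional import binary_auroc
import torch

preds = torch.rand(20)
target = torch.randint(0, 2, (20,))

binary_auroc(preds, target, thresholds=None)  # exact: stores all preds/targets
binary_auroc(preds, target, thresholds=100)   # binned: O(thresholds) memory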

This issue will be closed when https://github.com/Lightning-AI/metrics/pull/1195 is merged.

SkafteNicki avatar Aug 29 '22 13:08 SkafteNicki