
Approximate PR/ROC AUC Metrics

vfdev-5 opened this issue 4 years ago • 13 comments

🚀 Feature

Taken from this colab by @EricZimmermann

Description

Problem

For large tensors, computing AUC metrics over multiple thresholds is expensive and slow. For a sufficiently large dataset, caching or saving outputs is too costly and must be done in post-processing.

Solution

Assuming the distribution of values is known and bounded, approximate the integral (AUC) via a Riemann sum over a set of fixed or variable step sizes.

Let a set of monotonically increasing thresholds $T = \{t_1, t_2, \dots, t_n\}$ be the step sizes used to approximate the integral (AUC). For each iteration, cache the counts (build a histogram) of values falling into each bin between two thresholds. When complete, approximate the statistic using ratios of counts.

Place the calculation on the GPU for speed, removing the CPU bottleneck at each iteration.
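A minimal sketch of the idea (illustrative names, not the notebook's or Ignite's API): accumulate per-bin counts of positive and negative scores across batches, then derive TPR/FPR at every threshold and integrate with a Riemann (trapezoidal) sum. Scores are assumed bounded in [0, 1].

```python
import torch

# Assumed bounded score domain [0, 1]; bucketize can return indices 0..T,
# so the histograms have T + 1 bins.
thresholds = torch.linspace(0.0, 1.0, steps=101)
pos_hist = torch.zeros(thresholds.numel() + 1)
neg_hist = torch.zeros(thresholds.numel() + 1)

def update(scores: torch.Tensor, labels: torch.Tensor) -> None:
    """Bucketize one batch into threshold bins; tensors can live on the GPU."""
    bins = torch.bucketize(scores, thresholds)
    pos_hist.add_(torch.bincount(bins[labels == 1], minlength=pos_hist.numel()).float())
    neg_hist.add_(torch.bincount(bins[labels == 0], minlength=neg_hist.numel()).float())

def roc_auc() -> float:
    # Suffix sums: counts of positives/negatives scoring above each threshold bin.
    tpr = pos_hist.flip(0).cumsum(0).flip(0) / pos_hist.sum()
    fpr = neg_hist.flip(0).cumsum(0).flip(0) / neg_hist.sum()
    # Trapezoidal Riemann sum over the approximate ROC curve.
    return float(torch.trapz(tpr.flip(0), fpr.flip(0)))
```

Only the two histograms persist between batches, so the state is O(T) regardless of dataset size.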

Applications

Optimizing a model for voxel-level PR AUC / ROC AUC (each voxel treated as an independent sample), e.g. semantic pathology segmentation

Enabling the user to validate a model at its best operating point (e.g. F1 at a non-0.5 threshold)
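To illustrate the operating-point application, a brute-force sketch (hypothetical helper, not part of the proposal's code) that evaluates F1 at every candidate threshold and returns the maximizer, which is generally not 0.5:

```python
import torch

def best_f1_threshold(scores: torch.Tensor, labels: torch.Tensor,
                      thresholds: torch.Tensor) -> tuple:
    """Evaluate F1 at every candidate threshold; return (threshold, F1)."""
    # (T, N) prediction matrix: one row of binarized predictions per threshold.
    preds = scores.unsqueeze(0) >= thresholds.unsqueeze(1)
    pos = (labels == 1).unsqueeze(0)
    tp = (preds & pos).sum(1).float()
    fp = (preds & ~pos).sum(1).float()
    fn = (~preds & pos).sum(1).float()
    f1 = 2 * tp / (2 * tp + fp + fn + 1e-12)  # epsilon avoids 0/0
    i = int(f1.argmax())
    return float(thresholds[i]), float(f1[i])
```

In the thresholded-histogram design, the same per-threshold TP/FP/FN counts come directly from the cached bin counts, so no raw outputs need to be stored.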

Limitations

The domain (as thresholds) must be known ahead of time. This can be accounted for by setting many thresholds over a very wide range, where the number of thresholds << the number of voxels in the output tensor

Future work could use a heuristic to widen or narrow the number of thresholds based on previous iterations

Context:

  • https://github.com/Project-MONAI/MONAI/discussions/2982#discussioncomment-1621513

Code:

  • https://colab.research.google.com/drive/17Xo80swTxcaZsZOKS08GI-LHYHkg5kud

vfdev-5 avatar Nov 11 '21 14:11 vfdev-5

@vfdev-5

We've updated the colab notebook with more metrics and updated the test code with pytest methods.

JustinSzeto avatar Nov 30 '21 23:11 JustinSzeto

I searched a little bit and found that there are three approaches to computing AUC:

  1. Computing it exactly. It takes O(NlogN) time and O(N) space and is computed on the whole data at once. Scikit-learn and Transformers.evaluate use this approach.
  2. Approximating it using thresholds either given by the user or generated internally. Tensorflow uses this approach and implements it in O(NT) time and O(T) space. @EricZimmermann's implementation does it in O((N+T)logN) time and O(T) space.
  3. NEW! Approximating it using an unbiased estimator with known variance, specifically the Wilcoxon-Mann-Whitney statistic, in O(NlogN) time and O(1) space. This statistic computes AUC-ROC; for AUC-PR another estimator should be used.

By space I mean the space used between two consecutive batch updates, not the temporary space used during a batch update.
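For reference, the Wilcoxon-Mann-Whitney statistic in approach 3 can be sketched as follows (illustrative code, not an Ignite API; ties would need average ranks, which this sketch omits):

```python
import torch

def wmw_auc(scores: torch.Tensor, labels: torch.Tensor) -> float:
    """AUC-ROC as the Wilcoxon-Mann-Whitney statistic: the fraction of
    (positive, negative) pairs ranked correctly. O(NlogN) time via ranking."""
    order = scores.argsort()
    ranks = torch.empty_like(scores)
    ranks[order] = torch.arange(1, len(scores) + 1, dtype=scores.dtype)
    n_pos = int((labels == 1).sum())
    n_neg = len(scores) - n_pos
    # U statistic: sum of positive ranks minus its minimum possible value.
    u = float(ranks[labels == 1].sum()) - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)
```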

@JustinSzeto, @EricZimmermann, what are your thoughts on that?

sadra-barikbin avatar Jun 26 '22 07:06 sadra-barikbin

@sadra-barikbin

I'm a little confused about what you're asking. The biggest use case for this class is aggregating large sets of labels over time that cannot be stored in RAM, done as a confusion matrix over discrete elements. The confusion matrix lets you compute a whole set of statistics in O(T) at the end of an epoch or run. If we use the Wilcoxon-Mann-Whitney statistic, how would you aggregate it over mini-batches, assuming you cannot cache outputs from batch to batch? What would you do for all the other stats? I see the application for ROC-AUC alone, but then you have to cache everything in memory until the end of an epoch, which is exactly what we are trying to avoid.

Ex: we want to perform image segmentation on large images. Each image has 1e6 labels at 16 bits each. If your dataset is large, storing this in RAM is costly. Secondly, computing all statistics at the end of each epoch would take up cycles and slow things down considerably.

I think we could probably modify how the confusion matrix is done, considering there is a recursive threshold approach that works without the sorting in O(TN): values between t_{i} and t_{i+1} would be values less than t_{i+1} minus values less than t_{i}. We could look into how TF does it, but I imagine this is how it's done. I also want to note that these time complexities are a little tough to analyze on a parallelized system. We initially relied on the sort, which is O(NlogN), actually finishing much faster since its operations are not sequential. After that, the cumsum is O(N) and the search is O(TlogN) (again, parallel). Doing this directly would be O(NT).
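The sort-plus-search counting step described above can be sketched as (illustrative function name):

```python
import torch

def bin_counts(values: torch.Tensor, thresholds: torch.Tensor) -> torch.Tensor:
    """Counts of values in each [t_i, t_{i+1}) bin via one sort (O(NlogN))
    plus a binary search per threshold (O(TlogN))."""
    sorted_vals, _ = values.sort()
    # below[i] = number of values strictly less than thresholds[i]
    below = torch.searchsorted(sorted_vals, thresholds)
    # "values between t_i and t_{i+1}" = "values < t_{i+1}" - "values < t_i"
    return below[1:] - below[:-1]
```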

EricZimmermann avatar Jun 26 '22 18:06 EricZimmermann

You're right that trying to compute the metric at the end of the run is excruciating both in time and in memory, given large data samples.

Wilcoxon-Mann-Whitney could still be used without facing the problems you mentioned, by making a minor modification to it: specifically, comparing predictions in a batch only with themselves and not with those in other batches. This way we only keep a single float in memory between batches. Am I correct?

About your implementation vs. Tensorflow's: yours is better, I think, because it not only uses less memory during the computation of a batch (O(N) vs. O(NT)) but also takes less time if T > log(N), which might often be the case. Is this analysis right?

sadra-barikbin avatar Jun 27 '22 04:06 sadra-barikbin

@sadra-barikbin If you keep a running AUC, how do you sum them together batch to batch if you want to approximate AUC over the epoch or validation set?

Yes, you're right! We wrote this with medical image segmentation in mind (kind of like BraTS), so these are relatively large assumptions which seem to work out in this use case.

EricZimmermann avatar Jun 27 '22 13:06 EricZimmermann

[image: formulas — two running-sum updates per batch and a final division]

According to the formulas above, using two float variables initialized to zero at the beginning of the run, we apply the first two formulas at each batch and finally do a simple division at the end of the run.
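An illustrative sketch of this running accumulation (names assumed; a pairwise O(P·Q) comparison is used within each batch for clarity):

```python
import torch

# Two scalars persist between batches, initialized to zero at run start.
num_ordered_pairs = 0.0  # first running sum: correctly ordered (pos, neg) pairs
num_total_pairs = 0.0    # second running sum: all (pos, neg) pairs

def update(scores: torch.Tensor, labels: torch.Tensor) -> None:
    """Compare predictions only within this batch, per the modification above."""
    global num_ordered_pairs, num_total_pairs
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    # Pairwise score differences between every positive and every negative.
    diff = pos.unsqueeze(1) - neg.unsqueeze(0)
    num_ordered_pairs += float((diff > 0).sum()) + 0.5 * float((diff == 0).sum())
    num_total_pairs += pos.numel() * neg.numel()

def compute() -> float:
    # Simple division at the end of the run.
    return num_ordered_pairs / num_total_pairs
```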

sadra-barikbin avatar Jun 28 '22 04:06 sadra-barikbin

How would you integrate this with the rest of the PR? The whole basis of this is to also have access to everything you could compute with a confusion matrix. I see this being nice somewhere... but I'm not sure it fits the remaining requirements.

EricZimmermann avatar Jun 28 '22 19:06 EricZimmermann

What requirements do we have other than computing AUC itself?

sadra-barikbin avatar Jun 28 '22 22:06 sadra-barikbin

We have all the confusion-matrix-derived stats (F1, MCC, balanced accuracy, etc.)

EricZimmermann avatar Jun 28 '22 22:06 EricZimmermann

You're right that using Wilcoxon-Mann-Whitney we cannot compute those you mentioned, but their implementation (O(N) in time) already exists in Ignite, so the user could use that. What advantages does this implementation provide for those derived stats?

sadra-barikbin avatar Jun 28 '22 23:06 sadra-barikbin

For segmentation there are class imbalances, so looking at PR AUC is useful. A big topic is detection/segmentation model calibration, so this lets us track these things over time.

EricZimmermann avatar Jun 28 '22 23:06 EricZimmermann

Your first statement is correct. About the second one: by tracking things over time, do you mean having those derived stats over epochs?

sadra-barikbin avatar Jun 29 '22 04:06 sadra-barikbin

over batch updates with minimal compute requirements

EricZimmermann avatar Jul 04 '22 19:07 EricZimmermann