
FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold

Open glemaitre opened this issue 1 year ago • 17 comments

supersedes #16525
closes #16525
closes https://github.com/scikit-learn/scikit-learn/issues/8614
closes https://github.com/scikit-learn/scikit-learn/pull/10117
supersedes https://github.com/scikit-learn/scikit-learn/pull/10117

builds upon https://github.com/scikit-learn/scikit-learn/pull/26037

relates to #4813

Summary

We introduce a TunedThresholdClassifier meta-estimator that post-tunes the cut-off point used to convert the soft decision of decision_function or predict_proba into the hard decision returned by predict.
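
As a rough illustration of what "post-tuning the cut-off point" means, here is a dependency-free sketch (not the PR's implementation). The helper name `hard_decisions`, the `>=` tie convention, and the probabilities are assumptions made up for the example.

```python
def hard_decisions(scores, threshold):
    """Map soft scores to hard 0/1 labels: 1 when the score reaches the threshold."""
    return [1 if s >= threshold else 0 for s in scores]


# Hypothetical positive-class probabilities for five samples.
proba = [0.10, 0.35, 0.45, 0.65, 0.90]

# The usual implicit cut-off for predict_proba in the binary case.
default_labels = hard_decisions(proba, threshold=0.5)

# A lower cut-off flags more samples as positive without touching the model.
tuned_labels = hard_decisions(proba, threshold=0.3)
```

Only the mapping from scores to labels changes between the two calls; the scores themselves are untouched.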

Important features to have in mind:

objective_metric: the objective metric is set to either a metric to be maximized or a pair of metrics, one to be optimized under the constraint of the other (to find a trade-off). Additionally, we can pass a cost/gain-matrix that could be used to optimize a business metric. For this case, we are limited to constant cost/gain. In the future, we can think of cost/gain that depends on the matrix X but we would need to be able to forward meta-data to the scorer (a good additional use case for SLEP006 @adrinjalali).
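
To make the constant cost/gain case concrete, the following is a minimal sketch in plain Python, assuming a 2x2 gain matrix indexed as `gain[true][pred]` and a search restricted to the observed scores. The function names and toy data are hypothetical and do not reflect the PR's API.

```python
def total_gain(y_true, scores, threshold, gain):
    """Sum the gain of every decision; gain[true_label][predicted_label]."""
    y_pred = [1 if s >= threshold else 0 for s in scores]
    return sum(gain[t][p] for t, p in zip(y_true, y_pred))


def best_threshold_for_gain(y_true, scores, gain):
    """Try each distinct score as a cut-off and keep the one with the highest gain.

    Ties are resolved in favour of the smallest threshold.
    """
    candidates = sorted(set(scores))
    return max(candidates, key=lambda t: total_gain(y_true, scores, t, gain))


# Made-up validation data and a made-up constant gain matrix:
# true negative 0, false positive -1, false negative -5, true positive +10.
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.1, 0.3, 0.4, 0.6, 0.7, 0.9]
gain = [[0, -1], [-5, 10]]

best = best_threshold_for_gain(y_true, scores, gain)
best_gain = total_gain(y_true, scores, best, gain)
```

With these made-up numbers the search selects 0.4, because missing a positive (-5) is penalised more than raising a false alarm (-1).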

cv and refit: we provide some flexibility to pass an already fitted model or to use a single train/test split. We document the limitations and caveats with an example.

Points to discuss

  • Are we fine with the name TunedThresholdClassifier? Shall we instead have something about "threshold" (e.g. ThresholdTuner)?
  • We are using the terms objective_metric, constraint_value and objective_score. Is the naming fine? An alternative to "objective" might be "utility"?

Further work

I have currently implemented a single example that shows the feature in the context of post-tuning the decision threshold.

The current example is using a single train/test split for the figure and I think it would be nice to have some ROC/precision-recall curve obtained from cross-validation to be complete. However, we need some new features to be implemented.

I am also planning to analyse the usage of this feature on the problem of calibration on imbalanced classification problems. The feeling on this topic is that resampling strategies involve an implicit tuning of the decision threshold at the cost of a badly calibrated model. It might be better to learn a model on the imbalanced problem directly, making sure that it is well calibrated, and then post-tune the decision threshold for "hard" prediction. In this case, you get the best of two worlds: a calibrated model if the output of predict_proba is important to you and an optimum hard predictor for your specific utility metric. However, this is going to need some investigation and will be better suited for another PR.

glemaitre avatar Apr 07 '23 12:04 glemaitre

@glemaitre Thank you very much for this (extensive! due to examples, I see) PR.

To let you know my current point of view, let us focus on binary classification. First of all, YES, we urgently need tools for threshold selection and we need to make them prominent. I have even been considering writing a SLEP to add a threshold parameter to all classifiers, at least the ones with predict_proba, because the implicit 0.5 has so many consequences in terms of bad practice. I see 3 steps:

  1. Add the possibility to manually set a threshold. This is really important in my opinion. In particular for binary classification, the optimal threshold is determined by the cost ratio. And there might be cases where the different costs for false positives and false negatives are known.
  2. Add tools for automatic threshold selection (given a fitted model). This is a 1-dimensional optimization problem, so we could just use scipy.optimize. BTW, the ROC curve was invented for exactly this...
  3. Add CV-tools for automatic threshold selection.
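
For point 1, the link between costs and the threshold can be sketched as follows, assuming well-calibrated probabilities p = P(y=1|x), zero cost for correct decisions, and constant costs for false positives and false negatives. Predicting positive costs (1 - p) * cost_fp in expectation and predicting negative costs p * cost_fn, so predicting positive is cheaper when p > cost_fp / (cost_fp + cost_fn). The helper name is hypothetical.

```python
def cost_optimal_threshold(cost_fp, cost_fn):
    """Threshold on P(y=1|x) that minimises expected cost.

    Assumes calibrated probabilities, non-negative constant costs,
    and zero cost for correct decisions.
    """
    if cost_fp < 0 or cost_fn < 0 or cost_fp + cost_fn == 0:
        raise ValueError("costs must be non-negative and not both zero")
    return cost_fp / (cost_fp + cost_fn)


# Equal costs recover the usual implicit cut-off.
symmetric = cost_optimal_threshold(1, 1)

# A false negative that is 4x as costly as a false positive lowers the cut-off.
asymmetric = cost_optimal_threshold(1, 4)
```

Under these assumptions no search is needed: known costs translate directly into a threshold.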

This PR immediately starts with point 3 and that is my main concern.

Consider the use case: Assume I have a good model for predict_proba (maybe found via cv), and that model is now fixed. I apply that model in my application, and I choose a threshold for decisions/actions. By tracking the performance, I observe a certain shift of my KPIs. The first thing I can do is to tweak the threshold a bit, but leave the model otherwise untouched (no new training data, or too expensive to re-train, etc.).

lorentzenchr avatar May 04 '23 14:05 lorentzenchr

I think that 2. is also addressed with cv="prefit". We don't use scipy.optimize here and instead have a brute-force approach.

We could still implement 1. with an additional option for objective_metric="custom" and an additional parameter to set the decision threshold.

glemaitre avatar May 04 '23 14:05 glemaitre

We don't use scipy.optimize here and instead have a brute-force approach.

And indeed, for the precision-recall trade-off (e.g. to optimize for recall at a fixed precision constraint or the converse), it would not be possible to use scipy.optimize to tune the threshold as the PR curve is not concave in general. I think the brute force approach is fast enough in practice, simple to implement, easy to understand/debug/maintain and does not introduce any additional tunable parameter.
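
A minimal sketch of such a brute-force search in plain Python, assuming the candidate thresholds are the distinct observed scores and that precision is defined as 1.0 when nothing is predicted positive. The function names and toy data are hypothetical; this is not the code from the PR.

```python
def precision_recall_at(y_true, scores, threshold):
    """Precision and recall of the rule `score >= threshold`."""
    tp = fp = fn = 0
    for t, s in zip(y_true, scores):
        pred = 1 if s >= threshold else 0
        if pred == 1 and t == 1:
            tp += 1
        elif pred == 1 and t == 0:
            fp += 1
        elif pred == 0 and t == 1:
            fn += 1
    precision = tp / (tp + fp) if tp + fp else 1.0  # assumed convention
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


def max_recall_at_precision(y_true, scores, min_precision):
    """Scan every candidate threshold; keep the best recall meeting the constraint.

    Returns (threshold, recall, precision), or None if no threshold qualifies.
    """
    best = None
    for threshold in sorted(set(scores)):
        precision, recall = precision_recall_at(y_true, scores, threshold)
        if precision < min_precision:
            continue
        if best is None or recall > best[1]:
            best = (threshold, recall, precision)
    return best


# Made-up validation data.
y_true = [0, 0, 1, 0, 1, 1]
scores = [0.1, 0.3, 0.4, 0.6, 0.7, 0.9]

result = max_recall_at_precision(y_true, scores, min_precision=0.7)
strict = max_recall_at_precision(y_true, scores, min_precision=0.9)
```

On this toy data, precision goes 0.5, 0.6, 0.75, 0.667, 1.0, 1.0 as the threshold increases, so it is not even monotone; an exhaustive scan sidesteps that without any extra tunable parameter.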

ogrisel avatar Jun 01 '23 14:06 ogrisel

A possible follow-up for this PR: using SLEP 6 metadata to pass extra information used to compute sample-wise weights for the cost-matrix.

For instance, the cost of a false positive might vary and depend on side metadata that is part of neither X nor y.
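
A rough sketch of what sample-wise costs could look like, in plain Python. Everything here is hypothetical: the `amounts` metadata, the way costs are derived from it, and the function name are made up for illustration and say nothing about how SLEP 6 routing would actually deliver the data.

```python
def total_cost(y_true, scores, threshold, fp_costs, fn_costs):
    """Total cost of the rule `score >= threshold` with per-sample error costs."""
    cost = 0
    for t, s, c_fp, c_fn in zip(y_true, scores, fp_costs, fn_costs):
        pred = 1 if s >= threshold else 0
        if pred == 1 and t == 0:
            cost += c_fp  # false positive on this sample
        elif pred == 0 and t == 1:
            cost += c_fn  # false negative on this sample
    return cost


# Made-up data: `amounts` is side metadata that is part of neither X nor y.
y_true = [0, 1, 0, 1]
scores = [0.2, 0.4, 0.6, 0.8]
amounts = [10, 200, 50, 30]

# Assumed cost model: a false positive costs a tenth of the amount,
# a false negative costs the full amount.
fp_costs = [a // 10 for a in amounts]
fn_costs = amounts

best_threshold = min(
    sorted(set(scores)),
    key=lambda t: total_cost(y_true, scores, t, fp_costs, fn_costs),
)
```

Here the sample with the large amount dominates the choice: any threshold above its score of 0.4 incurs a cost of 200.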

ogrisel avatar Jun 01 '23 14:06 ogrisel

Are we fine with the name CutOffClassifier? Shall we instead have something about "threshold" (e.g. ThresholdTuner)?

I think I prefer the name ThresholdTuner as this is strictly speaking not a classifier itself and I would rather avoid future regrets. Maybe we can allude to the possibility of using cross-validation and name it something like ThresholdTunerCV, but that is also not very convincing as the cross-validation can be bypassed.

In any case, I think we should discuss a proper name for such an important meta-estimator :)

ArturoAmorQ avatar Jun 01 '23 14:06 ArturoAmorQ

In the context of @lorentzenchr's plan (https://github.com/scikit-learn/scikit-learn/pull/26120#issuecomment-1534891010), if we change the ClassifierMixin class to make it possible to choose a non-zero threshold for binary decision_function or a non-0.5 threshold for binary predict_proba, then the threshold tuner could be written as a function instead of a meta-estimator (a bit like cross_val_score). But from an API point of view, it would maybe be a bit weird to have a function that mutates an estimator (other than its fit and set_params methods).

I still find it interesting to have a meta-estimator though, to make it more direct to compute metrics that depend on hard predictions (e.g. classification reports, confusion matrix, and multimetric hyper-parameter search reports) on the outcome of a simple fit call.
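
A minimal sketch of the meta-estimator idea, assuming a wrapped object that exposes `fit` and a two-column `predict_proba`. `ThresholdWrapper` and `ToyScorer` are hypothetical names used for illustration only; they are not scikit-learn classes.

```python
class ThresholdWrapper:
    """Hypothetical wrapper: delegates fitting, applies its own cut-off in predict."""

    def __init__(self, estimator, threshold=0.5):
        self.estimator = estimator
        self.threshold = threshold

    def fit(self, X, y):
        self.estimator.fit(X, y)
        return self

    def predict_proba(self, X):
        return self.estimator.predict_proba(X)

    def predict(self, X):
        # Column 1 is assumed to hold the positive-class probability.
        return [1 if row[1] >= self.threshold else 0 for row in self.predict_proba(X)]


class ToyScorer:
    """Stand-in for a probabilistic classifier on 1-D inputs (illustrative only)."""

    def fit(self, X, y):
        return self  # nothing to learn in this toy

    def predict_proba(self, X):
        # Clip the single feature into [0, 1] and treat it as P(y=1).
        out = []
        for (x,) in X:
            p = min(max(x, 0.0), 1.0)
            out.append([1.0 - p, p])
        return out


X = [[0.2], [0.45], [0.7]]
y = [0, 0, 1]

default_labels = ThresholdWrapper(ToyScorer()).fit(X, y).predict(X)
lowered_labels = ThresholdWrapper(ToyScorer(), threshold=0.4).fit(X, y).predict(X)
```

Because the wrapper exposes an ordinary `predict`, anything that consumes hard predictions can use it after a single `fit` call, which is the convenience argued for above.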

ogrisel avatar Jun 01 '23 14:06 ogrisel

I think I like ThresholdTunerCV or ClassificationThresholdTunerCV to be a bit more explicit (albeit maybe too verbose).

ogrisel avatar Jun 01 '23 14:06 ogrisel

At least the CI is green enough to be able to read the rendered HTML for the example:

https://output.circle-artifacts.com/output/job/2fd8bbc2-a7a0-4509-a9b7-09fa52c7cb0f/artifacts/0/doc/auto_examples/model_selection/plot_cutoff_tuning.html

ogrisel avatar Jun 01 '23 14:06 ogrisel

In the context of @lorentzenchr's plan (#26120 (comment)), if we change the ClassifierMixin class to make it possible to choose a non-zero threshold for binary decision_function or a non-0.5 threshold for binary predict_proba ...

If we only had binary classifiers, then I would indeed push to include the threshold parameter in ClassifierMixin. For multiclass settings, this gets promoted to a matrix (with some redundancy), see e.g. https://arxiv.org/abs/1704.08979 for 3 classes. I'm undecided whether users and maintainers could be convinced to appreciate such parameters via ClassifierMixin. I still think it would be the cleanest solution, otherwise predict is simply nonsense without users being made aware of it.
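
One way to picture the multiclass case, as a sketch under the assumption of calibrated class probabilities and a constant cost matrix indexed as `cost[true][pred]`: predict the class with the lowest expected cost. This is a generic decision rule, not necessarily the formulation of the linked paper, and the function name and numbers are made up.

```python
def min_expected_cost_class(proba, cost):
    """Return the class k minimising sum over t of proba[t] * cost[t][k]."""
    n = len(proba)
    expected = [sum(proba[t] * cost[t][k] for t in range(n)) for k in range(n)]
    return min(range(n), key=expected.__getitem__)


proba = [0.5, 0.3, 0.2]  # hypothetical predicted probabilities for 3 classes

# With 0-1 costs the rule reduces to picking the most probable class.
zero_one = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
plain_choice = min_expected_cost_class(proba, zero_one)

# Making it expensive to miss class 2 changes the decision.
asymmetric = [[0, 1, 1], [1, 0, 1], [10, 10, 0]]
costly_choice = min_expected_cost_class(proba, asymmetric)
```

In the binary case the same rule collapses to a single threshold on the positive-class probability.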

The fact that it is hard to decide on a good name for the current meta-estimator, e.g. include CV at the end of the name or not, is a sign, IMHO, that it tries to do too many things at once.

If we go with it, ClassificationThresholdTunerCV might indeed be the correct name. (BTW, HistGradientBoostingClassifier has the exact same name length.)

And don't get me wrong: Great work @glemaitre, I really appreciate this PR!

lorentzenchr avatar Jun 01 '23 17:06 lorentzenchr

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 9bd68e6. Link to the linter CI: here

github-actions[bot] avatar Jul 03 '23 10:07 github-actions[bot]

It'd probably make sense to add metadata routing here from the start. Let me know if I should add it, or if you do, let me know and I can review it.

adrinjalali avatar Jul 07 '23 10:07 adrinjalali

It'd probably make sense to add metadata routing here from the start. Let me know if I should add it, or if you do, let me know and I can review it.

Indeed, I have two features to add:

  • Use metadata routing
  • Enable a manual tuning

I will ping you if I run into issues.

glemaitre avatar Jul 07 '23 10:07 glemaitre

Thanks @solegalli. I will address the comment. I first need to make the metadata routing work properly and add some additional tests before revisiting the documentation.

glemaitre avatar Jul 14 '23 20:07 glemaitre

Requires a merge of https://github.com/scikit-learn/scikit-learn/pull/26840

glemaitre avatar Sep 28 '23 09:09 glemaitre

To remove the remaining failure, we would need to bump from pandas 1.0 to 1.5. Since this is a bit drastic, I can try to handle the plot with matplotlib. WDYT @ogrisel ?

glemaitre avatar Dec 04 '23 16:12 glemaitre

To remove the remaining failure, we would need to bump from pandas 1.0 to 1.5.

I was wrong. We only need to bump to 1.1. So I assume this is enough and there is no need to change the plot.

glemaitre avatar Dec 04 '23 19:12 glemaitre

Given the constraints for review, what do you think about making TunedThresholdClassifier experimental for 1.4?

thomasjpfan avatar Dec 10 '23 02:12 thomasjpfan

I think that this is good for another round of review.

glemaitre avatar May 03 '24 15:05 glemaitre

LGTM

Deja-vu :)

glemaitre avatar May 03 '24 15:05 glemaitre

🎉

jeremiedbb avatar May 03 '24 16:05 jeremiedbb

The issue about refactoring is already open here: https://github.com/scikit-learn/scikit-learn/issues/28941.

For the documentation, I'll open an issue to see how to articulate the constrained part. The comment raised by @amueller is no longer meaningful since we don't really allow choosing a point on the PR or ROC curve in a straightforward manner.

I'll also revive https://github.com/scikit-learn/scikit-learn/pull/17930

glemaitre avatar May 03 '24 16:05 glemaitre

:rocket: @glemaitre 🎉 Thanks for this great addition many have been longing for.

In a way, I like the explicit FixedThresholdClassifier. Out of curiosity: is there a path forward where this could end up as a mixin class in every classifier? Without the need to meta-class-wrap it by the user?

lorentzenchr avatar May 03 '24 18:05 lorentzenchr

Out of curiosity: is there a path forward where this could end up as a mixin class in every classifier? Without the need to meta-class-wrap it by the user?

Maybe but that would entail adding at least 2 new constructor parameters to all classifiers in scikit-learn that implement predict_proba or decision_function. That's kind of an invasive API change...

ogrisel avatar May 06 '24 16:05 ogrisel