FEA add TunedThresholdClassifier meta-estimator to post-tune the cut-off threshold
Supersedes and closes #16525, closes https://github.com/scikit-learn/scikit-learn/issues/8614, closes and supersedes https://github.com/scikit-learn/scikit-learn/pull/10117.
Builds upon https://github.com/scikit-learn/scikit-learn/pull/26037.
Relates to #4813.
Summary
We introduce a TunedThresholdClassifier that post-tunes the cut-off point used to convert the soft decisions of decision_function or predict_proba into the hard decisions returned by predict.
Important features to keep in mind:
- objective_metric: the objective metric is either a single metric to maximize, or a pair of metrics where one is optimized under a constraint on the other (to find a trade-off). Additionally, we can pass a cost/gain matrix that can be used to optimize a business metric; in that case, we are limited to constant costs/gains. In the future, we can think of costs/gains that depend on the matrix X, but we would then need to be able to forward metadata to the scorer (a good additional use case for SLEP006 @adrinjalali). A rough usage sketch is given right after this list.
- cv and refit: we provide some flexibility to pass a prefitted model or a single train/test split. We document the limitations and caveats, with an example.
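To give a feel for the intended usage, here is a rough sketch based only on the parameter names described above (objective_metric, cv); the exact signature is still under review and may change:

```python
# Rough usage sketch based only on the parameter names discussed in this PR
# (`objective_metric`, `cv`); the exact signature may still change during review.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Tune the cut-off point to maximize a metric of interest, using internal CV
# so that the threshold is not selected on the training predictions.
tuned = TunedThresholdClassifier(          # the meta-estimator added in this PR
    estimator=LogisticRegression(),
    objective_metric="balanced_accuracy",
    cv=5,
).fit(X_train, y_train)

# `predict` now uses the tuned threshold instead of the implicit 0.5.
y_pred = tuned.predict(X_test)
```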
Points to discuss
- Are we fine with the name TunedThresholdClassifier? Should we instead have something with "threshold" in it (e.g. ThresholdTuner)?
- We are using the terms objective_metric, constraint_value and objective_score. Is the naming fine? An alternative to "objective" might be "utility".
Further work
I currently implemented a single example that shows the feature in the context of post-tuning the decision threshold.
The current example uses a single train/test split for the figure, and I think it would be nice to also show ROC/precision-recall curves obtained from cross-validation to be complete. However, this requires some new features to be implemented first.
I am also planning to analyse the usage of this feature for the problem of calibration on imbalanced classification problems. My feeling on this topic is that resampling strategies involve an implicit tuning of the decision threshold at the cost of a badly calibrated model. It might be better to learn a model on the imbalanced problem directly, make sure that it is well calibrated, and then post-tune the decision threshold for "hard" predictions. In this case, you get the best of both worlds: a calibrated model if the output of predict_proba is important to you, and an optimal hard predictor for your specific utility metric. However, this needs some investigation and is better suited for another PR.
@glemaitre Thank you very much for this (extensive, due to the examples, I see!) PR.
To let you know my current point of view, let us focus on binary classification. First of all, YES, we urgently need tools for threshold selection, and we should make them prominent.
I have even been considering writing a SLEP to add a threshold parameter to all classifiers, at least the ones with predict_proba, because the implicit 0.5 has so many consequences of bad practice. I see 3 steps:
- Add the possibility to manually set a threshold. This is really important in my opinion. In particular, for binary classification the cost ratio determines the optimal threshold, and there might be cases where the different costs for false positives and false negatives are known.
- Add tools for automatic threshold selection (given a fitted model). This is a 1-dimensional optimization problem, so we could just use scipy.optimize (a rough sketch follows this list). BTW, the ROC curve was invented for exactly this...
- Add CV tools for automatic threshold selection.
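As a rough illustration of step 2 only (not what this PR implements), the 1-D problem can be handed to scipy.optimize; note that the metric is a step function of the threshold, which is one practical argument for a brute-force grid instead:

```python
# Step 2 as a plain 1-D optimization over the threshold of a fitted binary
# classifier (illustration only; the metric is piecewise constant in the
# threshold, so a smooth optimizer may get stuck on a plateau).
from scipy.optimize import minimize_scalar
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
proba = LogisticRegression().fit(X_train, y_train).predict_proba(X_val)[:, 1]

def neg_score(threshold):
    # Metric to maximize, evaluated on held-out data for a candidate threshold.
    return -balanced_accuracy_score(y_val, (proba >= threshold).astype(int))

result = minimize_scalar(neg_score, bounds=(0.0, 1.0), method="bounded")
print(f"selected threshold: {result.x:.3f}")
```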
This PR immediately starts with point 3 and that is my main concern.
Consider the use case: assume I have a good model for predict_proba (maybe found via CV); that is now fixed. I apply that model in my application, and I choose a threshold for decisions/actions. By tracking the performance, I observe a certain shift in my KPIs. The first thing I can do is tweak the threshold a bit, but otherwise leave the model untouched (no new training data, or too expensive to re-train, etc.).
I think that point 2 is also addressed with cv="prefit". We don't use scipy.optimize here and instead have a brute-force approach.
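To illustrate, a rough sketch of that prefit use case (same caveat: parameter names as discussed in this PR, subject to change):

```python
# Rough sketch of the prefit use case: the already-deployed model is left
# untouched and only the decision threshold is tuned on fresh data.
# Parameter names follow this PR's discussion and may still change.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
X_old, X_new, y_old, y_new = train_test_split(X, y, random_state=0)

deployed = LogisticRegression().fit(X_old, y_old)   # trained earlier, kept fixed
tuned = TunedThresholdClassifier(
    estimator=deployed,
    objective_metric="balanced_accuracy",
    cv="prefit",            # do not refit `deployed`, only tune the threshold
).fit(X_new, y_new)
```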
We could still implement point 1 with an additional option objective_metric="custom" and an additional parameter to set the decision threshold.
> We don't use scipy.optimize here and instead have a brute-force approach.
And indeed, for the precision-recall trade-off (e.g. optimizing recall at a fixed precision constraint, or the converse), it would not be possible to use scipy.optimize to tune the threshold since the PR curve is not concave in general. I think the brute-force approach is fast enough in practice, simple to implement, easy to understand/debug/maintain, and does not introduce any additional tunable parameter.
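To make the constrained trade-off concrete, here is a rough sketch of the brute-force idea (illustration only, not this PR's code path): maximize recall subject to a minimum precision by scanning every candidate threshold on the PR curve, with no concavity assumption needed.

```python
# Brute-force scan of the PR curve: maximize recall subject to precision >= 0.8.
# Illustration only; not the code path implemented in this PR.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import precision_recall_curve
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, weights=[0.9, 0.1], random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X, y, random_state=0)
proba = LogisticRegression().fit(X_train, y_train).predict_proba(X_val)[:, 1]

precision, recall, thresholds = precision_recall_curve(y_val, proba)
feasible = precision[:-1] >= 0.8         # `thresholds` has one fewer entry
assert feasible.any(), "no threshold satisfies the precision constraint"
best = np.argmax(np.where(feasible, recall[:-1], -np.inf))
print(f"threshold={thresholds[best]:.3f}, "
      f"precision={precision[best]:.3f}, recall={recall[best]:.3f}")
```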
A possible follow-up for this PR: using SLEP006 metadata routing to pass extra information used to compute sample-wise weights for the cost matrix. For instance, the cost of a false positive might vary and depend on side metadata that is part of neither X nor y.
> Are we fine with the name CutOffClassifier? Should we instead have something with "threshold" in it (e.g. ThresholdTuner)?
I think I prefer the name ThresholdTuner, as strictly speaking this is not a classifier itself, and I would rather avoid future regrets. Maybe we could allude to the possibility of using cross-validation and name it something like ThresholdTunerCV, but that is also not very convincing since the cross-validation can be bypassed.
In any case, I think we should discuss a proper name for such an important meta-estimator :)
In the context of @lorentzenchr's plan (https://github.com/scikit-learn/scikit-learn/pull/26120#issuecomment-1534891010), if we change the ClassifierMixin class to make it possible to choose a non-zero threshold for binary decision_function or a non-0.5 threshold for binary predict_proba, then the threshold tuner could be written as a function instead of a meta-estimator (a bit like cross_val_score). But from an API point of view, it would maybe be a bit weird to have a function that mutates an estimator (other than through its fit and set_params methods).
I still find it interesting to have a meta-estimator, though, to make it more direct to compute metrics that depend on hard predictions (e.g. classification reports, confusion matrices, and multi-metric hyper-parameter search reports) on the outcome of a simple fit call.
I think I like ThresholdTunerCV or ClassificationThresholdTunerCV to be a bit more explicit (albeit maybe too verbose).
At least the CI is green enough to be able to read the rendered HTML for the example:
https://output.circle-artifacts.com/output/job/2fd8bbc2-a7a0-4509-a9b7-09fa52c7cb0f/artifacts/0/doc/auto_examples/model_selection/plot_cutoff_tuning.html
> In the context of @lorentzenchr's plan (#26120 (comment)), if we change the ClassifierMixin class to make it possible to choose a non-zero threshold for binary decision_function or a non-0.5 threshold for binary predict_proba ...
If we only had binary classifiers, then I would indeed push to include the threshold parameter in ClassifierMixin. For multiclass settings, this gets promoted to a matrix (with some redundancy), see e.g. https://arxiv.org/abs/1704.08979 for 3 classes. I'm undecided whether users and maintainers could be convinced to appreciate such parameters via ClassifierMixin. I still think it would be the cleanest solution; otherwise, predict just produces nonsense without this being noted anywhere.
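For the record, a rough illustration of that generalization (assuming constant, known costs; predict_min_expected_cost is a made-up helper, not an existing API): the single binary cut-off becomes an expected-cost minimization under a full cost matrix, which for two classes with zero cost on correct predictions reduces to thresholding at the cost ratio mentioned in point 1 above.

```python
# Multiclass generalization of the binary threshold (illustration only):
# hard predictions minimize the expected cost under a user-given cost matrix.
import numpy as np

def predict_min_expected_cost(proba, cost_matrix):
    """proba: (n_samples, n_classes) probabilities; cost_matrix[k, j] is the
    cost of predicting class j when the true class is k."""
    expected_cost = proba @ cost_matrix          # (n_samples, n_classes)
    return np.argmin(expected_cost, axis=1)

# 3-class example where mistakes on class 2 are much more expensive.
cost = np.array([[0.0, 1.0, 1.0],
                 [1.0, 0.0, 1.0],
                 [5.0, 5.0, 0.0]])
proba = np.array([[0.5, 0.3, 0.2],
                  [0.6, 0.3, 0.1]])
print(predict_min_expected_cost(proba, cost))    # [2 0], not the plain argmax
```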
The fact that it is hard to decide on a good name for the current meta-estimator, e.g. include CV at the end of the name or not, is a sign, IMHO, that it tries to do too many things at once.
If we go with it, ClassificationThresholdTunerCV might indeed be the correct name. (BTW, HistGradientBoostingClassifier has the exact same name length.)
And don't get me wrong: Great work @glemaitre, I really appreciate this PR!
It'd probably make sense to add metadata routing here from the start. Let me know if I should add it, or if you do, let me know and I can review it.
> It'd probably make sense to add metadata routing here from the start. Let me know if I should add it, or if you do, let me know and I can review it.
Indeed, I have two features to add:
- Use metadata routing
- Enable manual tuning
I will ping you if I run into issues.
Thanks @solegalli. I will address the comment. I first need to make the metadata routing work properly and add some additional tests before revisiting the documentation.
Requires a merge of https://github.com/scikit-learn/scikit-learn/pull/26840
To remove the remaining failure, we would need to bump from pandas 1.0 to 1.5. Since this is a bit drastic, I can try to handle the plot with matplotlib. WDYT @ogrisel ?
> To remove the remaining failure, we would need to bump from pandas 1.0 to 1.5.
I was wrong; we only need to bump to 1.1. I assume this is enough, so no need to change the plot.
Given the constraints for review, what do you think about making TunedThresholdClassifier experimental for 1.4?
I think that this is good for another round of review.
LGTM
Deja-vu :)
🎉
The issue about refactoring is already open here: https://github.com/scikit-learn/scikit-learn/issues/28941.
For the documentation, I'll open an issue to see how to articulate the constrained part. The comment raised by @amueller is no longer relevant since we don't really allow choosing a point on the PR or ROC curve in a straightforward manner.
I'll also revive https://github.com/scikit-learn/scikit-learn/pull/17930
:rocket: @glemaitre 🎉 Thanks for this great addition that many have been longing for.
In a way, I like the explicit FixedThresholdClassifier. Out of curiosity: is there a path forward for this to end up as a mixin in every classifier, without the user needing to wrap the estimator in a meta-estimator?
> Out of curiosity: is there a path forward for this to end up as a mixin in every classifier, without the user needing to wrap the estimator in a meta-estimator?
Maybe, but that would entail adding at least 2 new constructor parameters to all classifiers in scikit-learn that implement predict_proba or decision_function. That's kind of an invasive API change...
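For what it's worth, a simplified sketch (not the actual FixedThresholdClassifier code) of why the wrapper route sidesteps that concern: the user-chosen cut-off lives in the wrapper, so no classifier constructor needs new parameters.

```python
# Simplified sketch of a fixed-threshold wrapper (not the actual
# FixedThresholdClassifier implementation): the explicit cut-off lives in the
# wrapper instead of in every classifier's constructor.
from sklearn.base import BaseEstimator, ClassifierMixin, clone

class SimpleFixedThreshold(ClassifierMixin, BaseEstimator):
    def __init__(self, estimator, threshold=0.5):
        self.estimator = estimator
        self.threshold = threshold

    def fit(self, X, y):
        self.estimator_ = clone(self.estimator).fit(X, y)
        self.classes_ = self.estimator_.classes_
        return self

    def predict(self, X):
        # Binary case: apply the explicit cut-off instead of the implicit 0.5.
        proba = self.estimator_.predict_proba(X)[:, 1]
        return self.classes_[(proba >= self.threshold).astype(int)]
```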