
[ENH] Allow fit_resample to receive metadata routed parameters

Open ShimantoRahman opened this issue 11 months ago • 5 comments

Is your feature request related to a problem? Please describe

In cost-sensitive learning, resampling techniques are used to address the asymmetrical importance of data points. These techniques require the amount of resampling to depend on instance-specific parameters, such as cost weights associated with individual data points. These cost weights are usually collected in a cost matrix for each data point $i$:

|  | Actual Positive ($y_i = 1$) | Actual Negative ($y_i = 0$) |
|---|---|---|
| Predicted Positive ($\hat{y}_i = 1$) | $C_{TP_i}$ | $C_{FP_i}$ |
| Predicted Negative ($\hat{y}_i = 0$) | $C_{FN_i}$ | $C_{TN_i}$ |

Since these cost weights depend on the data point, they cannot be predetermined during initialization (`__init__`) but must instead be computed dynamically from the input data during the `fit_resample` process.
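Concretely, such per-instance costs can be pictured as a stacked array of 2x2 matrices, one per sample. The sketch below is purely illustrative; the `[i, predicted, actual]` index convention and the `misclassification_cost` name are assumptions for this example, not an imbalanced-learn API:

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 5
y = rng.integers(0, 2, size=n_samples)

# One 2x2 cost matrix per data point i, laid out as in the table above:
# first index = sample, second = predicted label, third = actual label.
cost_matrix = np.zeros((n_samples, 2, 2))
cost_matrix[:, 1, 0] = rng.uniform(1, 5, size=n_samples)   # C_FP_i (pred=1, actual=0)
cost_matrix[:, 0, 1] = rng.uniform(1, 10, size=n_samples)  # C_FN_i (pred=0, actual=1)

# The cost of misclassifying point i depends on its true label, so it can
# only be computed once y is known -- hence the need to pass the cost
# matrix to fit_resample instead of fixing it in __init__.
misclassification_cost = np.where(y == 1, cost_matrix[:, 0, 1], cost_matrix[:, 1, 0])
```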

The current imbalanced-learn `Pipeline` implementation does not natively support passing metadata through its `fit_resample` method. Metadata routing, which would let instance-dependent parameters flow seamlessly through the pipeline, is critical for implementing cost-sensitive learning workflows.

Desired workflow (DOES NOT CURRENTLY WORK)

import numpy as np
from imblearn.pipeline import Pipeline
from sklearn import set_config
from sklearn.utils._metadata_requests import MetadataRequest, RequestMethod
from sklearn.base import BaseEstimator
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression


set_config(enable_metadata_routing=True)

class CostSensitiveSampler(BaseEstimator):

    _estimator_type = "sampler"
    __metadata_request__fit_resample = {'cost_matrix': True}

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit_resample(self, X, y, cost_matrix=None):
        # resample based on cost_matrix
        # ...
        return X, y

    def _get_metadata_request(self):
        routing = MetadataRequest(owner=self.__class__.__name__)
        routing.fit_resample.add_request(param='cost_matrix', alias=True)
        return routing

    set_fit_resample_request = RequestMethod('fit_resample', ['cost_matrix'])

X, y = make_classification()
cost_matrix = np.random.rand(X.shape[0], 2, 2)
pipeline = Pipeline([
    ('sampler', CostSensitiveSampler().set_fit_resample_request(cost_matrix=True)),
    ('model', LogisticRegression())
])
pipeline.fit(X, y, cost_matrix=cost_matrix)

Describe the solution you'd like

From what I understand of the metadata routing implementation of the Pipeline object, only a couple of changes have to be made:

  1. the SIMPLE_METHODS constant found here needs to include "fit_resample":
SIMPLE_METHODS = [
            "fit",
            "partial_fit",
            "fit_resample",  # add line here
            "predict",
            "predict_proba",
            "predict_log_proba",
            "decision_function",
            "score",
            "split",
            "transform",
            "inverse_transform",
        ]

Note that this does require imbalanced-learn to redefine the classes and functions that use the SIMPLE_METHODS constant internally (these are currently imported from scikit-learn when scikit-learn 1.4 or higher is installed), including MetadataRequest and _MetadataRequester.

  2. A method mapping from caller "fit" to callee "fit_resample" has to be added in the get_metadata_routing(self) method found here, and the filter_resample parameter of the self._iter method needs to be set to False:

def get_metadata_routing(self):
    """Get metadata routing of this object.

    Please check :ref:`User Guide <metadata_routing>` on how the routing
    mechanism works.

    Returns
    -------
    routing : MetadataRouter
        A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
        routing information.
    """
    router = MetadataRouter(owner=self.__class__.__name__)

    # first we add all steps except the last one
    for _, name, trans in self._iter(
        with_final=False, filter_passthrough=True, filter_resample=False  # change filter_resample to False
    ):
        method_mapping = MethodMapping()
        # fit, fit_predict, and fit_transform call fit_transform if it
        # exists, or else fit and transform
        if hasattr(trans, "fit_transform"):
            (
                method_mapping.add(caller="fit", callee="fit_transform")
                .add(caller="fit_transform", callee="fit_transform")
                .add(caller="fit_predict", callee="fit_transform")
                .add(caller="fit_resample", callee="fit_transform")
            )
        else:
            (
                method_mapping.add(caller="fit", callee="fit")
                .add(caller="fit", callee="transform")
                .add(caller="fit_transform", callee="fit")
                .add(caller="fit_transform", callee="transform")
                .add(caller="fit_predict", callee="fit")
                .add(caller="fit_predict", callee="transform")
                .add(caller="fit_resample", callee="fit")
                .add(caller="fit_resample", callee="transform")
            )

        (
            method_mapping.add(caller="predict", callee="transform")
            .add(caller="predict_proba", callee="transform")
            .add(caller="decision_function", callee="transform")
            .add(caller="predict_log_proba", callee="transform")
            .add(caller="transform", callee="transform")
            .add(caller="inverse_transform", callee="inverse_transform")
            .add(caller="score", callee="transform")
            .add(caller="fit_resample", callee="transform")
            .add(caller="fit", callee="fit_resample")  # add this line
        )

        router.add(method_mapping=method_mapping, **{name: trans})
    # add final estimator method mapping
    ...

Additional context

I am a PhD researcher; I used these methods for my paper, and I am the author of a Python package, Empulse, which implements samplers that require cost parameters to be passed to the fit_resample method, as in the dummy example above (see Empulse/Samplers). I find the whole metadata routing implementation incredibly confusing, so apologies if I made some mistakes in my reasoning.

ShimantoRahman avatar Dec 16 '24 13:12 ShimantoRahman

On the principle, I think it would be nice to accept metadata indeed.

For your specific use case, I'm not sure that resampling is actually the best approach. While working on the scikit-learn project, we found that resampling breaks the calibration of the classifier, and what users are actually trying to solve can usually be achieved by post-tuning the decision threshold of the classifier.

We recently added TunedThresholdClassifierCV to scikit-learn, and we show an example of cost-sensitive learning in the documentation: https://scikit-learn.org/1.5/auto_examples/model_selection/plot_cost_sensitive_learning.html#sphx-glr-auto-examples-model-selection-plot-cost-sensitive-learning-py

We also worked on the following tutorial to show some internals that could be of interest to you: https://probabl-ai.github.io/calibration-cost-sensitive-learning/intro.html

glemaitre avatar Dec 18 '24 21:12 glemaitre

So I think that we had an underlying bug in get_metadata_routing.

I got your example and made a minimal reproducer:

https://github.com/scikit-learn-contrib/imbalanced-learn/pull/1115/files#diff-82b96c4de3880642afa90f01a32ca3b1dbac2918d037990a7826ba4dc206a939R1501-R1519

So it means that it should work out of the box.

glemaitre avatar Dec 20 '24 15:12 glemaitre

> For your specific use case, I'm not sure that resampling is actually the best approach. While working on the scikit-learn project, we found that resampling breaks the calibration of the classifier, and what users are actually trying to solve can usually be achieved by post-tuning the decision threshold of the classifier.

Thank you for your recommendation. A couple of days ago I watched your podcast with Vincent Warmerdam on the Probabl YouTube channel. It was quite insightful and prompted me to read the scikit-learn documentation you linked above, which definitely changed my perspective on the problem. I was planning to do some benchmarking of my own once I finished implementing some of the techniques I found in the literature, and I will definitely explore calibration further.

I just tested out version 0.13.0 and it works like a charm! Thank you for the quick implementation and my best wishes this holiday period <3

ShimantoRahman avatar Dec 22 '24 13:12 ShimantoRahman

One small suggestion in relation to type checking: as of now, type checkers will not recognize the set_fit_resample_request method because it is dynamically constructed at runtime. Perhaps adding this to the SamplerMixin could be useful:

from abc import ABCMeta
from typing import TYPE_CHECKING


class SamplerMixin(metaclass=ABCMeta):
    """Mixin class for samplers with abstract method.

    Warning: This class should not be used directly. Use the derived classes
    instead.
    """

    _estimator_type = "sampler"

    if TYPE_CHECKING:
        def set_fit_resample_request(self, **kwargs): pass

    ...
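A quick runtime check (using a hypothetical `Demo` class, not part of imbalanced-learn) shows why the `if TYPE_CHECKING:` guard is safe: the stub is visible only to static type checkers and never shadows the dynamically attached method:

```python
from typing import TYPE_CHECKING


class Demo:
    if TYPE_CHECKING:
        # Seen only by static type checkers; never executed at runtime.
        def greet(self) -> str: ...


# TYPE_CHECKING is False at runtime, so the stub is never defined and the
# dynamically attached implementation (standing in for RequestMethod) wins.
Demo.greet = lambda self: "hello"
print(Demo().greet())  # prints "hello"
```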

ShimantoRahman avatar Dec 22 '24 13:12 ShimantoRahman

Let me reopen to not forget about this last issue. Thanks for reporting.

glemaitre avatar Dec 22 '24 14:12 glemaitre