tune-sklearn
Context set with `set_config` is not preserved
Since SLEP018, scikit-learn (1.2 and later) lets you choose the output container of transformers globally through the existing `set_config` API. One use case is propagating transformed values through a pipeline as pandas DataFrames, via `set_config(transform_output="pandas")`.
I ran into this while running hyperparameter optimization (HPO) with `TuneGridSearchCV`, which does not preserve the previously set global context. A minimal reproduction:
`model.py`:

```python
import sklearn
from sklearn.pipeline import Pipeline
from sklearn.linear_model import Ridge
from sklearn.preprocessing import StandardScaler
from sklearn.feature_selection import VarianceThreshold

# This sets the global context.
sklearn.set_config(transform_output="pandas")


class CustomVarianceThreshold(VarianceThreshold):
    # Fails unless X arrives as a pandas DataFrame
    # (a NumPy array has no `columns` attribute).
    def fit(self, X, y=None):
        assert X.columns is not None
        return super().fit(X, y)

    def transform(self, X):
        assert X.columns is not None
        return super().transform(X)


def MODEL():
    return Pipeline([
        ("scaler", StandardScaler()),
        ("selector", CustomVarianceThreshold()),
        ("regressor", Ridge()),
    ])
```
`main.py`:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from tune_sklearn import TuneGridSearchCV
from model import MODEL

PARAMS = {
    "regressor__alpha": np.linspace(0, 1, 100)
}


# This works fine.
def run_single():
    model = MODEL()
    data = load_iris(as_frame=True)
    model.fit(data.data, data.target)
    model.predict(data.data)


# This also works fine. HPO is enabled, but Ray is not the backend.
def run_hpo():
    model = MODEL()
    cv = GridSearchCV(model, PARAMS, n_jobs=-1)
    data = load_iris(as_frame=True)
    cv.fit(data.data, data.target)
    best_model = cv.best_estimator_
    best_model.predict(data.data)


# This breaks because `columns` is not an attribute of X
# (X does not reach CustomVarianceThreshold as a DataFrame).
def run_hpo_with_ray():
    model = MODEL()
    cv = TuneGridSearchCV(model, PARAMS, n_jobs=-1)
    data = load_iris(as_frame=True)
    cv.fit(data.data, data.target)
    best_model = cv.best_estimator_
    best_model.predict(data.data)


if __name__ == "__main__":
    run_single()
    run_hpo()
    run_hpo_with_ray()  # expected to fail, see comment above
```
There is a workaround: set the output per estimator instead of globally, e.g. replace `StandardScaler()` with `StandardScaler().set_output(transform="pandas")`. Still, it would be nice if the global setting made via `set_config` were respected by `tune_sklearn`.
I searched the existing issues and documentation but did not find anything related. Please let me know if this is a duplicate; apologies if so.