skglm

FEAT - Make GeneralizedLinearEstimator compatible with GridSearchCV

Open · mathurinm opened this issue · 0 comments

Because the parameters we would like to cross-validate are parameters of model.penalty, model.datafit or model.solver, and not of the estimator itself, GeneralizedLinearEstimator is not compatible with GridSearchCV:

from skglm.utils.data import make_correlated_data
from skglm.datafits import Quadratic
from skglm.penalties import L1
from skglm import GeneralizedLinearEstimator

from sklearn.model_selection import GridSearchCV
import numpy as np

X, y, _ = make_correlated_data()

model = GeneralizedLinearEstimator(Quadratic(), L1(alpha=1))
alpha_grid = np.geomspace(1, 1e-2)

cv = GridSearchCV(model, param_grid={"alpha": alpha_grid}, scoring="neg_mean_squared_error").fit(X, y)

This fails with: TypeError: GeneralizedLinearEstimator.__init__() got an unexpected keyword argument 'penalty__alpha'
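For context, scikit-learn only lets GridSearchCV reach a nested parameter through a key like penalty__alpha when the nested object itself implements get_params/set_params: BaseEstimator.get_params(deep=True) lists penalty__alpha only if the penalty has a get_params method, and set_params forwards alpha=... to penalty.set_params. The sketch below illustrates that contract with hypothetical toy classes (PlainL1, ToyL1, ToyGLM are made up for this example and are not skglm's API), using sklearn's Lasso as a stand-in solver. It shows the mechanism only, not how skglm's own classes would have to change.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin
from sklearn.linear_model import Lasso
from sklearn.model_selection import GridSearchCV


class PlainL1:
    """Hypothetical penalty with no get_params/set_params."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha


class ToyL1(BaseEstimator):
    """Hypothetical penalty: BaseEstimator derives get_params/set_params
    from the __init__ signature."""

    def __init__(self, alpha=1.0):
        self.alpha = alpha


class ToyGLM(RegressorMixin, BaseEstimator):
    """Hypothetical estimator that stores its penalty as a constructor argument."""

    def __init__(self, penalty=None):
        self.penalty = penalty

    def fit(self, X, y):
        alpha = 1.0 if self.penalty is None else self.penalty.alpha
        # Delegate the actual solve to sklearn's Lasso, purely for illustration.
        self.model_ = Lasso(alpha=alpha).fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)


rng = np.random.default_rng(0)
X = rng.standard_normal((100, 10))
y = X[:, :3] @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.standard_normal(100)
alpha_grid = np.geomspace(1, 1e-2, num=5)

# A penalty without get_params is opaque: no "penalty__alpha" key is exposed.
plain_model = ToyGLM(PlainL1())
print(sorted(plain_model.get_params()))  # ['penalty']

# A penalty deriving from BaseEstimator exposes its own parameters.
model = ToyGLM(ToyL1())
print(sorted(model.get_params()))  # ['penalty', 'penalty__alpha']

# GridSearchCV clones `model` and routes each value to penalty.set_params(alpha=...).
cv = GridSearchCV(model, param_grid={"penalty__alpha": alpha_grid},
                  scoring="neg_mean_squared_error").fit(X, y)
print(cv.best_params_)
```

Since GridSearchCV works on clones, the original model.penalty.alpha stays at 1.0; the selected value ends up on cv.best_estimator_.penalty.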

How could we solve this? Maybe @glemaitre or @agramfort has an idea?

mathurinm · May 31 '24 11:05