scikeras
scikeras copied to clipboard
How can I use the validation_data parameter instead of validation_split?
If we use validation_split, there will be data leakage. To avoid this, we need to keep the validation set constant by passing it through the validation_data parameter instead of using validation_split.
def get_model(hidden_layer_dim, meta):
    """Build the uncompiled binary classifier handed to ``KerasClassifier``.

    Layers, in order: Dense(hidden_layer_dim) -> BatchNormalization ->
    LeakyReLU(alpha=0.3) -> Dense(1, sigmoid).

    Args:
        hidden_layer_dim: Number of units in the hidden Dense layer. The
            layer width used to be hard-coded to 4, which made this argument
            (and any grid search over it) a no-op.
        meta: Mapping that must contain ``"X_shape_"``. Its leading axis is
            dropped to obtain the per-sample input shape (presumably the
            sample axis -- confirm against the wrapper's documentation).

    Returns:
        A ``tf.keras.models.Sequential`` instance; ``compile`` is not called
        here.
    """
    # Everything after the leading axis of the training array's shape.
    input_shape = meta["X_shape_"][1:]

    return tf.keras.models.Sequential(
        [
            tf.keras.layers.Dense(
                units=hidden_layer_dim,
                input_shape=input_shape,
                # Fixed seed so the initial hidden weights are reproducible.
                kernel_initializer=tf.keras.initializers.HeNormal(seed=0),
            ),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha=0.3),
            # Single sigmoid unit: one probability per sample.
            tf.keras.layers.Dense(units=1, activation="sigmoid"),
        ]
    )
# Early stopping driven by the validation F1 score: higher is better
# (mode="max"), an epoch only counts as an improvement if the score rises by
# at least min_delta, and training stops after `patience` epochs without one.
callback = tf.keras.callbacks.EarlyStopping(
    restore_best_weights=True,
    min_delta=1e-4,
    patience=10,
    mode="max",
    monitor="val_f1_score",
)
# scikeras wrapper around get_model so the network can sit in a scikit-learn
# pipeline. Double-underscore keywords (metrics__*) are routed parameters;
# see the scikeras "routed parameters" documentation for the exact rules.
clf = KerasClassifier(
    # Model construction.
    model=get_model,
    hidden_layer_dim=100,
    # Compilation: loss plus an F1 metric configured via routed parameters.
    loss="binary_crossentropy",
    metrics=tfa.metrics.F1Score,
    metrics__num_classes=1,
    metrics__average="macro",
    metrics__threshold=0.3,
    # Training loop.
    epochs=100,
    validation_split=0.1,
    callbacks=[callback],
    verbose=False,
)
def sel_score(X, y):
    """Score every feature by its mutual information with the target.

    Used as the ``score_func`` of ``SelectKBest``. The first three columns of
    ``X`` are treated as continuous and every remaining column as discrete.
    The mask is derived from the width of ``X`` rather than spelled out as a
    fixed-length literal, so the function keeps working if the number of
    discrete columns produced upstream changes.

    Args:
        X: 2-D feature matrix (anything exposing a two-element ``shape``, or
            convertible to an array).
        y: Target labels, one per row of ``X``.

    Returns:
        Whatever ``mutual_info_classif`` returns for these inputs (one score
        per column of ``X``).
    """
    import numpy as np  # local import: the file's import block is not in view

    n_continuous = 3
    n_columns = np.shape(X)[1]
    # Same pattern as the former literal: False for the leading continuous
    # columns, True for everything after them.
    discrete_mask = [column >= n_continuous for column in range(n_columns)]

    return mutual_info_classif(
        X=X,
        y=y,
        discrete_features=discrete_mask,
        random_state=0,
    )
# Full estimator: preprocessing (`composer`, defined elsewhere), then keep the
# 17 best-scoring features according to sel_score, then the Keras classifier.
pipe = Pipeline(
    [
        ("composer", composer),
        ("sel", SelectKBest(sel_score, k=17)),
        ("clf", clf),
    ]
)
# Search space for the grid search. Keys use the "<step>__<param>" form to
# address the "clf" step of the pipeline; only the optimizer has more than
# one candidate value.
params = {
    "clf__optimizer": ["sgd", "adam"],
    "clf__loss": ["binary_crossentropy"],
    "clf__hidden_layer_dim": [8],
    "clf__metrics__threshold": [0.3],
}
# 3-fold grid search over `params`, scored with macro-averaged F1 restricted
# to the "Yes" label. refit=False: no final model is refitted afterwards.
gs = GridSearchCV(
    estimator=pipe,
    param_grid=params,
    scoring=make_scorer(f1_score, average="macro", labels=["Yes"]),
    cv=3,
    refit=False,
    n_jobs=-1,
)
gs.fit(X_train_val, y_train_val)
I wonder whether there is an answer to this thread?
Have you tried passing in fit__validation_data to the .fit() call?
https://adriangb.com/scikeras/stable/advanced.html#routed-parameters
Have you tried passing in `fit__validation_data` to the `.fit()` call?
fit__validation_data is not among the fit() parameters in scikeras; only fit__validation_split is.
Closed, old age issue