xgb.cv stops earlier when custom eval_metric is provided (R)
Hi! I am cross-validating a multinomial classification model with early stopping, and I observe that it stops earlier when I provide a custom eval_metric.
See below for a dummy example on iris, in which the custom metric is mlogloss. When I calculate the mlogloss through my custom function, it stops earlier (mean cv$best_iteration is 6 over 100 runs). When I use eval_metric = "mlogloss" it stops later (mean cv$best_iteration is 86 over 100 runs). There is no real difference in test performance on the iris dataset (if I train with the "winning" nrounds), but there is a serious difference on my own dataset.
Of course, ultimately this is not what I really want to do. I want to fit a model with the multinomial objective but cross-validate using binary logloss for my class of interest, and in that case I get exactly the same behaviour: when I provide a custom eval_metric, it stops very early.
Any suggestions or explanations of what is happening are welcome. Thanks!
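To make that second use case concrete, here is roughly the kind of custom metric I have in mind. This is only a sketch: class_of_interest, the 0-based label encoding, and the assumption that preds arrives as a row-major vector of length n * num_class (the same assumption as in the repro below) are mine, not confirmed behaviour.

library(xgboost)
class_of_interest <- 1  # 0-based label of the class I care about (placeholder)
n_classes <- 3          # e.g. for iris, redefined in the repro below

logloss_binary_class <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  # assumed: preds is a vector of length n * n_classes, one row per observation
  preds <- matrix(preds, ncol = n_classes, byrow = TRUE)
  y <- as.integer(labels == class_of_interest)
  # clip probabilities to avoid log(0)
  p <- pmin(pmax(preds[, class_of_interest + 1], 1e-15), 1 - 1e-15)
  list(metric = "binary_logloss_class",
       value = -mean(y * log(p) + (1 - y) * log(1 - p)))
}

The full reproducible example with the mlogloss comparison: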
library(xgboost)
library(ModelMetrics)
library(tidyverse)
data(iris)
set.seed(2024)
cv_best_iter <- tibble(i = integer(),
                       cv_type = character(),
                       best_iter = integer())
test_auc <- tibble(i = integer(),
                   cv_type = character(),
                   outcome = integer(),
                   auc = double())
for (i in 1:100){
  print(i)
  N <- nrow(iris)
  p_test <- 1/3
  id_test <- sample(1:N, round(N*p_test))
  iris_test <- iris[id_test,]
  iris_train <- iris[-id_test,]
  objective <- "multi:softprob"
  train_y_label <- as.integer(iris_train$Species) - 1
  data_train_xgb <- xgb.DMatrix(data = as.matrix(iris_train[,1:4]),
                                label = train_y_label)
  n_classes <- length(levels(iris_train$Species))
  logloss_m_obj <- function(preds, dtrain) {
    labels <- getinfo(dtrain, "label")
    labels <- factor(labels)
    # preds should be a matrix
    preds <- matrix(preds, ncol = n_classes, byrow = TRUE)
    m_logloss <- ModelMetrics::mlogLoss(labels, preds)
    #m_logloss <- yardstick::mn_log_loss_vec(labels, preds) # gives same results
    return(list(metric = "m_logloss", value = m_logloss))
  }
  # cv with custom function
  cv <- xgb.cv(params = list(booster = "gbtree",
                             eta = 0.1,
                             objective = objective,
                             eval_metric = logloss_m_obj,
                             #eval_metric = "mlogloss",
                             num_class = n_classes),
               data = data_train_xgb,
               nrounds = 10000, # Set this large and use early stopping
               nthread = 12, # parallel
               nfold = 5,
               prediction = TRUE,
               showsd = FALSE,
               early_stopping_rounds = 25, # If evaluation metric does not improve on out-of-fold sample for x rounds, stop
               maximize = FALSE,
               verbose = 0)
  # plot(cv$evaluation_log$iter, cv$evaluation_log$test_m_logloss_mean)
  cv_best_iter <- cv_best_iter %>%
    add_row(i = i,
            cv_type = "custom_mlogloss",
            best_iter = cv$best_iteration)
  # sapply(1:3, function(y)
  #   sapply(cv$folds, function(x) ModelMetrics::auc(ifelse(train_y_label == y-1, 1, 0)[x], cv$pred[x,y]))) %>%
  #   colMeans()
  # train with winning nrounds
  xgb_model <- xgboost(params = list(booster = "gbtree",
                                     objective = objective,
                                     eta = 0.1,
                                     num_class = n_classes),
                       data = data_train_xgb,
                       nrounds = cv$best_iteration,
                       nthread = 12,
                       verbose = 0)
  test_preds <- predict(xgb_model, xgb.DMatrix(data = as.matrix(iris_test[,1:4])))
  test_preds_matrix <- matrix(test_preds, ncol = n_classes, byrow = TRUE)
  test_aucs <- sapply(1:3, function(y)
    ModelMetrics::auc(ifelse(as.integer(iris_test$Species) == y, 1, 0), test_preds_matrix[,y]))
  test_auc <- test_auc %>%
    add_row(i = i,
            cv_type = "custom_mlogloss",
            outcome = c(1, 2, 3),
            auc = test_aucs)
  # CV with mlogloss
  cv <- xgb.cv(params = list(booster = "gbtree",
                             eta = 0.1,
                             objective = objective,
                             #eval_metric = logloss_m_obj,
                             eval_metric = "mlogloss",
                             num_class = n_classes),
               data = data_train_xgb,
               nrounds = 10000, # Set this large and use early stopping
               nthread = 12, # parallel
               nfold = 5,
               prediction = TRUE,
               showsd = FALSE,
               early_stopping_rounds = 25, # If evaluation metric does not improve on out-of-fold sample for x rounds, stop
               maximize = FALSE,
               verbose = 0)
  # plot(cv$evaluation_log$iter, cv$evaluation_log$test_mlogloss_mean)
  #
  # sapply(1:3, function(y)
  #   sapply(cv$folds, function(x) ModelMetrics::auc(ifelse(train_y_label == y-1, 1, 0)[x], cv$pred[x,y]))) %>%
  #   colMeans()
  cv_best_iter <- cv_best_iter %>%
    add_row(i = i,
            cv_type = "mlogloss",
            best_iter = cv$best_iteration)
  # train with winning nrounds
  xgb_model <- xgboost(params = list(booster = "gbtree",
                                     objective = objective,
                                     eta = 0.1,
                                     num_class = n_classes),
                       data = data_train_xgb,
                       nrounds = cv$best_iteration,
                       nthread = 12,
                       verbose = 0)
  test_preds <- predict(xgb_model, xgb.DMatrix(data = as.matrix(iris_test[,1:4])))
  test_preds_matrix <- matrix(test_preds, ncol = n_classes, byrow = TRUE)
  test_aucs <- sapply(1:3, function(y)
    ModelMetrics::auc(ifelse(as.integer(iris_test$Species) == y, 1, 0), test_preds_matrix[,y]))
  test_auc <- test_auc %>%
    add_row(i = i,
            cv_type = "mlogloss",
            outcome = c(1, 2, 3),
            auc = test_aucs)
}
cv_best_iter %>%
  ggplot(aes(best_iter)) +
  geom_histogram(bins = 50) +
  facet_wrap(~cv_type)

test_auc %>%
  group_by(outcome, cv_type) %>%
  summarise(mean_auc = mean(auc)) %>%
  ungroup()
Session info:
> sessionInfo()
R version 4.2.3 (2023-03-15 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)
Matrix products: default
locale:
[1] LC_COLLATE=English_Belgium.utf8 LC_CTYPE=English_Belgium.utf8
[3] LC_MONETARY=English_Belgium.utf8 LC_NUMERIC=C
[5] LC_TIME=English_Belgium.utf8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] lubridate_1.9.2 forcats_1.0.0 stringr_1.5.0 dplyr_1.1.4
[5] purrr_1.0.1 readr_2.1.4 tidyr_1.3.0 tibble_3.2.1
[9] ggplot2_3.4.4 tidyverse_2.0.0 ModelMetrics_1.2.2.2 xgboost_1.7.5.1
loaded via a namespace (and not attached):
[1] Rcpp_1.0.10 lattice_0.20-45 listenv_0.9.0 class_7.3-21
[5] digest_0.6.31 ipred_0.9-14 foreach_1.5.2 utf8_1.2.3
[9] parallelly_1.36.0 R6_2.5.1 plyr_1.8.8 hardhat_1.3.1
[13] stats4_4.2.3 pillar_1.9.0 rlang_1.1.3 caret_6.0-93
[17] rstudioapi_0.15.0 data.table_1.14.8 rpart_4.1.19 Matrix_1.6-1
[21] labeling_0.4.3 splines_4.2.3 gower_1.0.1 munsell_0.5.0
[25] compiler_4.2.3 pkgconfig_2.0.3 globals_0.16.2 nnet_7.3-18
[29] tidyselect_1.2.0 prodlim_2019.11.13 pmcalibration_0.1.0 codetools_0.2-19
[33] fitdistrplus_1.1-11 fansi_1.0.4 future_1.33.0 tzdb_0.3.0
[37] withr_2.5.0 MASS_7.3-58.2 recipes_1.0.10 grid_4.2.3
[41] nlme_3.1-162 jsonlite_1.8.4 gtable_0.3.4 lifecycle_1.0.3
[45] DBI_1.1.3 magrittr_2.0.3 pROC_1.18.0 scales_1.3.0
[49] future.apply_1.11.0 cli_3.6.2 stringi_1.7.12 farver_2.1.1
[53] ROCR_1.0-11 reshape2_1.4.4 doParallel_1.0.17 timeDate_4022.108
[57] ellipsis_0.3.2 generics_0.1.3 vctrs_0.6.5 lava_1.7.2.1
[61] CalibratR_0.1.2 iterators_1.0.14 tools_4.2.3 glue_1.6.2
[65] hms_1.1.2 parallel_4.2.3 survival_3.5-3 timechange_0.2.0
[69] colorspace_2.1-0 splitTools_1.0.1 precrec_0.14.2