xgboost icon indicating copy to clipboard operation
xgboost copied to clipboard

xgb.cv stops earlier when custom eval_metric is provided (R)

Open sibipx opened this issue 7 months ago • 3 comments

Hi! I am cross-validating a multinomial classification model using early stopping, and I observe that when I provide a custom eval_metric it stops earlier.

See below a dummy example on iris, in which the custom metric is mlogloss. When I calculate the mlogloss through my custom function, it stops earlier (mean cv$best_iteration is 6 over 100 runs). When I use eval_metric = "mlogloss" it stops later (mean cv$best_iteration is 86 over 100 runs). There is no real difference in test performance on the iris dataset (if I train with the "winning" nround), but there is a serious difference on my dataset.

Of course, ultimately this is not what I really want to do. I want to fit a model using a multinomial objective but cross-validate using binary logloss for my class of interest. And in that case, I get exactly the same behaviour: when I provide a custom eval_metric, it stops very early.

Any suggestions or explanation on what is happening are welcome. Thanks!

library(xgboost)
library(ModelMetrics)
library(tidyverse)

# Fix the RNG seed so the random splits drawn later are reproducible, and
# load the built-in iris data set.
set.seed(2024)
data(iris)

# Empty accumulator: one row per (repeat, CV variant) with the iteration at
# which early stopping picked the best model.
cv_best_iter <- tibble(
  i = integer(0),
  cv_type = character(0),
  best_iter = integer(0)
)

# Empty accumulator: one row per (repeat, CV variant, class) with the
# one-vs-rest AUC measured on the held-out test split.
test_auc <- tibble(
  i = integer(0),
  cv_type = character(0),
  outcome = integer(0),
  auc = numeric(0)
)

# Custom evaluation metric: multiclass log-loss for a "multi:softprob" model.
#
# preds  : predictions handed over by xgboost, either a flat vector laid out
#          observation by observation (one run of class scores per row) or a
#          matrix with one column per class.
# dtrain : xgb.DMatrix carrying the labels (0-based class ids).
#
# Returns list(metric = "m_logloss", value = <log-loss>), the shape xgboost
# expects from a custom evaluation function.
logloss_m_obj <- function(preds, dtrain) {
  labels <- getinfo(dtrain, "label")
  n_classes <- length(preds) / length(labels)

  # NOTE(review): newer xgboost releases may already pass a matrix here;
  # only reshape when a flat vector is received -- confirm per version.
  if (!is.matrix(preds)) {
    preds <- matrix(preds, ncol = n_classes, byrow = TRUE)
  }

  # NOTE(review): the xgboost 1.7.x R package appears to call custom metrics
  # with raw margins (untransformed scores) rather than probabilities, which
  # makes a log-loss computed directly on `preds` meaningless and triggers
  # premature early stopping -- confirm against the installed version.
  # Rows that are not already probabilities are pushed through a softmax.
  is_probability <- all(preds >= 0) && all(abs(rowSums(preds) - 1) < 1e-4)
  if (!is_probability) {
    # Subtract the row maximum first for numerical stability; the vector is
    # recycled down the columns, so element [r, c] is shifted by row r's max.
    shifted <- exp(preds - apply(preds, 1, max))
    preds <- shifted / rowSums(shifted)
  }

  # Pin the factor levels to 0..k-1 so a class that is absent from a fold
  # does not shift the label -> column mapping.
  labels <- factor(labels, levels = seq_len(n_classes) - 1)

  list(metric = "m_logloss", value = ModelMetrics::mlogLoss(labels, preds))
}

# Run 5-fold CV with early stopping for one eval metric, refit on the full
# training data with the selected number of rounds, and score the test split.
#
# eval_metric : built-in metric name or a custom metric function.
# dtrain      : xgb.DMatrix with the training features and labels.
# test_x      : numeric feature matrix of the test split.
# test_y      : 1-based integer class ids of the test split.
# n_classes   : number of classes.
#
# Returns list(best_iter = <selected iteration>, aucs = <AUC per class>).
fit_and_score <- function(eval_metric, dtrain, test_x, test_y, n_classes) {
  base_params <- list(booster = "gbtree",
                      eta = 0.1,
                      objective = "multi:softprob",
                      num_class = n_classes)

  cv <- xgb.cv(params = c(base_params, list(eval_metric = eval_metric)),
               data = dtrain,
               nrounds = 10000, # set this large and rely on early stopping
               nthread = 12,
               nfold = 5,
               prediction = TRUE,
               showsd = FALSE,
               # stop when the out-of-fold metric has not improved for 25
               # consecutive rounds
               early_stopping_rounds = 25,
               maximize = FALSE,
               verbose = 0)
  # cv$evaluation_log holds the per-iteration metric if a plot is wanted.

  # Refit on all training data with the winning number of rounds.
  xgb_model <- xgboost(params = base_params,
                       data = dtrain,
                       nrounds = cv$best_iteration,
                       nthread = 12,
                       verbose = 0)

  test_preds <- predict(xgb_model, xgb.DMatrix(data = test_x))
  test_preds_matrix <- matrix(test_preds, ncol = n_classes, byrow = TRUE)

  # One-vs-rest AUC for every class, computed on the predicted probability
  # of that class.
  aucs <- vapply(seq_len(n_classes), function(k) {
    ModelMetrics::auc(as.numeric(test_y == k), test_preds_matrix[, k])
  }, numeric(1))

  list(best_iter = cv$best_iteration, aucs = aucs)
}

# Fraction of observations held out as the test split in every repeat.
p_test <- 1 / 3

for (i in seq_len(100)) {
  print(i)

  N <- nrow(iris)
  id_test <- sample(seq_len(N), round(N * p_test))

  iris_test <- iris[id_test, ]
  iris_train <- iris[-id_test, ]

  # xgboost expects 0-based class ids.
  train_y_label <- as.integer(iris_train$Species) - 1

  data_train_xgb <- xgb.DMatrix(data = as.matrix(iris_train[, 1:4]),
                                label = train_y_label)

  n_classes <- length(levels(iris_train$Species))

  # Same pipeline twice, differing only in the metric that drives early
  # stopping: the custom function first, then the built-in "mlogloss".
  cv_variants <- list(custom_mlogloss = logloss_m_obj,
                      mlogloss = "mlogloss")

  for (cv_type in names(cv_variants)) {
    scored <- fit_and_score(eval_metric = cv_variants[[cv_type]],
                            dtrain = data_train_xgb,
                            test_x = as.matrix(iris_test[, 1:4]),
                            test_y = as.integer(iris_test$Species),
                            n_classes = n_classes)

    cv_best_iter <- cv_best_iter %>%
      add_row(i = i,
              cv_type = cv_type,
              best_iter = scored$best_iter)

    test_auc <- test_auc %>%
      add_row(i = i,
              cv_type = cv_type,
              outcome = seq_len(n_classes),
              auc = scored$aucs)
  }
}

# Histogram of the early-stopping iteration, one panel per CV variant.
ggplot(cv_best_iter, aes(x = best_iter)) +
  geom_histogram(bins = 50) +
  facet_wrap(~cv_type)

# Mean test AUC for every (class, CV variant) combination.
ungroup(
  summarise(
    group_by(test_auc, outcome, cv_type),
    mean_auc = mean(auc)
  )
)

Session info:

> sessionInfo()
R version 4.2.3 (2023-03-15 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19045)

Matrix products: default

locale:
[1] LC_COLLATE=English_Belgium.utf8  LC_CTYPE=English_Belgium.utf8   
[3] LC_MONETARY=English_Belgium.utf8 LC_NUMERIC=C                    
[5] LC_TIME=English_Belgium.utf8    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] lubridate_1.9.2      forcats_1.0.0        stringr_1.5.0        dplyr_1.1.4         
 [5] purrr_1.0.1          readr_2.1.4          tidyr_1.3.0          tibble_3.2.1        
 [9] ggplot2_3.4.4        tidyverse_2.0.0      ModelMetrics_1.2.2.2 xgboost_1.7.5.1     

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.10         lattice_0.20-45     listenv_0.9.0       class_7.3-21       
 [5] digest_0.6.31       ipred_0.9-14        foreach_1.5.2       utf8_1.2.3         
 [9] parallelly_1.36.0   R6_2.5.1            plyr_1.8.8          hardhat_1.3.1      
[13] stats4_4.2.3        pillar_1.9.0        rlang_1.1.3         caret_6.0-93       
[17] rstudioapi_0.15.0   data.table_1.14.8   rpart_4.1.19        Matrix_1.6-1       
[21] labeling_0.4.3      splines_4.2.3       gower_1.0.1         munsell_0.5.0      
[25] compiler_4.2.3      pkgconfig_2.0.3     globals_0.16.2      nnet_7.3-18        
[29] tidyselect_1.2.0    prodlim_2019.11.13  pmcalibration_0.1.0 codetools_0.2-19   
[33] fitdistrplus_1.1-11 fansi_1.0.4         future_1.33.0       tzdb_0.3.0         
[37] withr_2.5.0         MASS_7.3-58.2       recipes_1.0.10      grid_4.2.3         
[41] nlme_3.1-162        jsonlite_1.8.4      gtable_0.3.4        lifecycle_1.0.3    
[45] DBI_1.1.3           magrittr_2.0.3      pROC_1.18.0         scales_1.3.0       
[49] future.apply_1.11.0 cli_3.6.2           stringi_1.7.12      farver_2.1.1       
[53] ROCR_1.0-11         reshape2_1.4.4      doParallel_1.0.17   timeDate_4022.108  
[57] ellipsis_0.3.2      generics_0.1.3      vctrs_0.6.5         lava_1.7.2.1       
[61] CalibratR_0.1.2     iterators_1.0.14    tools_4.2.3         glue_1.6.2         
[65] hms_1.1.2           parallel_4.2.3      survival_3.5-3      timechange_0.2.0   
[69] colorspace_2.1-0    splitTools_1.0.1    precrec_0.14.2  

sibipx avatar Jul 16 '24 09:07 sibipx