
Account for possibility of custom objective function in XGBoost `boost_tree()`

Open • smingerson opened this issue 3 years ago • 3 comments

Currently, passing a custom objective function causes an error downstream at prediction time. The error comes from xgb_pred(), which calls switch() on the objective (usually a character string) to post-process the output of predict.xgb.Booster(). When the objective is a function rather than a single string, switch() cannot dispatch on it and fails.
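A minimal base-R sketch of the failure and one possible guard. `transform_xgb_pred()` is a hypothetical helper, loosely modelled on the switch() visible in the error message of the reprex; it is not parsnip's actual xgb_pred() code. The guard returns the raw predictions whenever the objective is not a single string (for example, a custom function or NULL).

```r
# Hypothetical helper, loosely modelled on the switch() in the error message.
# NOT parsnip's real xgb_pred() implementation.
transform_xgb_pred <- function(res, objective) {
  # switch() needs a single string (or number); a custom objective function
  # or a NULL objective makes it fail with "EXPR must be a length 1 vector".
  if (!is.character(objective) || length(objective) != 1L) {
    return(res)  # no known link function, so leave predictions untouched
  }
  switch(objective,
    "binary:logitraw" = stats::binomial()$linkinv(res),  # raw logits -> probabilities
    res                                                  # default: unchanged
  )
}

raw <- c(-2, 0, 2)
custom_obj <- function(preds, dtrain) list(grad = preds, hess = preds)

# The unguarded switch() reproduces the error from the reprex.
err <- tryCatch(switch(custom_obj, a = 1), error = conditionMessage)
err  # "EXPR must be a length 1 vector"

# The guarded version passes predictions through for non-string objectives.
transform_xgb_pred(raw, custom_obj)
```

Returning the raw predictions is only one design choice: with a custom objective the package cannot know the link function, so any transformation would have to be applied by the user.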

library(xgboost)
library(parsnip)
library(workflows)
mod <- boost_tree("regression") %>% 
  set_engine("xgboost",
             objective = function(preds, dtrain) {
               truth <- as.numeric(getinfo(dtrain, "label"))
               error <- truth - preds
               gradient <- -2 * error
               hess <- rep.int(2, length(preds))
               list(grad = gradient, hess = hess)
             }
             )

dt <- data.frame(x = rnorm(15))
dt$y <- dt$x + rnorm(15, 0, .05)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(y~x)
fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
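The custom objective in this reprex returns the gradient and Hessian of the squared-error loss (truth - preds)^2 with respect to the predictions. A quick finite-difference check in base R confirms the formulas. `sq_err_grad_hess()` is a hypothetical standalone version that takes the labels directly instead of calling xgboost's getinfo(dtrain, "label"), so the check runs without xgboost.

```r
# Squared-error loss per observation.
sq_loss <- function(preds, truth) (truth - preds)^2

# Same gradient/Hessian formulas as the reprex objective (hypothetical
# standalone version: labels are passed in directly).
sq_err_grad_hess <- function(preds, truth) {
  error <- truth - preds
  list(grad = -2 * error, hess = rep.int(2, length(preds)))
}

truth <- c(1, -0.5, 2)
preds <- c(0.8, 0, 2.5)
analytic <- sq_err_grad_hess(preds, truth)

# Central differences; the loss is quadratic, so these are exact up to rounding.
h <- 1e-4
num_grad <- (sq_loss(preds + h, truth) - sq_loss(preds - h, truth)) / (2 * h)
num_hess <- (sq_loss(preds + h, truth) - 2 * sq_loss(preds, truth) +
               sq_loss(preds - h, truth)) / h^2

grad_ok <- isTRUE(all.equal(analytic$grad, num_grad, tolerance = 1e-6))
hess_ok <- isTRUE(all.equal(analytic$hess, num_hess, tolerance = 1e-4))
c(grad_ok = grad_ok, hess_ok = hess_ok)
```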

smingerson • Apr 2, 2021, 02:04

I also see this error when using parsnip::set_engine("xgboost", params = list(eval_metric = "aucpr")) without setting the objective argument. I ran into it after updating parsnip from 0.1.4 to 0.1.5, when tune::tune_grid() started failing. (tidymodels and the other individual packages, e.g. {workflows} and {tune}, were updated at the same time.)

This test passes: https://github.com/tidymodels/parsnip/blob/cb086385a90227eacfce2f06ed58ff2d4e17bb29/tests/testthat/test_boost_tree_xgboost.R#L169

  spec <-
    boost_tree() %>%
    set_engine("xgboost", objective = "reg:pseudohubererror") %>%
    set_mode("regression")

  xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)

But if objective is not a string (presumably why this is labelled a feature request rather than a bug), it fails as in the OP's code. Additionally, adding anything else to set_engine() triggers the same error. Are all the ... arguments added to the same vector?

library(xgboost)
library(parsnip)
library(workflows)

mod <- boost_tree("classification") %>% 
  set_engine(
    "xgboost", 
    objective = "binary:logistic",
    params = list(eval_metric = "aucpr") # <- added this and changed the data to be a classification problem
  )

dt <- data.frame(
  x = rnorm(15),
  y = rnorm(15) + rnorm(15, 0, .05),
  target = as.factor(rbinom(15, 1, 0.5))
)

wf <- workflow() %>% 
  add_model(mod) %>% 
  add_formula(target ~ x + y)

fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector

Created on 2021-04-20 by the reprex package (v2.0.0)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.0.2 (2020-06-22)
#>  os       macOS  10.16                
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Europe/Berlin               
#>  date     2021-04-20                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  ! package     * version date       lib source        
#>    assertthat    0.2.1   2019-03-21 [1] CRAN (R 4.0.2)
#>  P cli           2.4.0   2021-04-05 [?] CRAN (R 4.0.2)
#>  P codetools     0.2-18  2020-11-04 [3] CRAN (R 4.0.2)
#>  P crayon        1.4.1   2021-02-08 [?] CRAN (R 4.0.2)
#>  P data.table    1.14.0  2021-02-21 [?] CRAN (R 4.0.2)
#>  P DBI           1.1.1   2021-01-15 [?] CRAN (R 4.0.2)
#>    digest        0.6.27  2020-10-24 [1] CRAN (R 4.0.2)
#>  P dplyr         1.0.5   2021-03-05 [?] CRAN (R 4.0.2)
#>    ellipsis      0.3.1   2020-05-15 [1] CRAN (R 4.0.2)
#>  P evaluate      0.14    2019-05-28 [?] CRAN (R 4.0.0)
#>  P fansi         0.4.2   2021-01-15 [?] CRAN (R 4.0.2)
#>  P fs            1.5.0   2020-07-31 [?] CRAN (R 4.0.2)
#>  P generics      0.1.0   2020-10-31 [?] CRAN (R 4.0.2)
#>    globals       0.14.0  2020-11-22 [1] CRAN (R 4.0.2)
#>    glue          1.4.2   2020-08-27 [1] CRAN (R 4.0.2)
#>  P hardhat       0.1.5   2020-11-09 [?] CRAN (R 4.0.2)
#>  P highr         0.9     2021-04-16 [?] CRAN (R 4.0.2)
#>  P htmltools     0.5.1.1 2021-01-22 [?] CRAN (R 4.0.2)
#>  P knitr         1.32    2021-04-14 [?] CRAN (R 4.0.2)
#>  P lattice       0.20-41 2020-04-02 [3] CRAN (R 4.0.2)
#>  P lifecycle     1.0.0   2021-02-15 [?] CRAN (R 4.0.2)
#>    magrittr      2.0.1   2020-11-17 [1] CRAN (R 4.0.2)
#>  P Matrix        1.3-2   2021-01-06 [?] CRAN (R 4.0.2)
#>  P parsnip     * 0.1.5   2021-01-19 [?] CRAN (R 4.0.2)
#>  P pillar        1.6.0   2021-04-13 [?] CRAN (R 4.0.2)
#>    pkgconfig     2.0.3   2019-09-22 [1] CRAN (R 4.0.2)
#>    purrr         0.3.4   2020-04-17 [1] CRAN (R 4.0.2)
#>    R6            2.5.0   2020-10-28 [1] CRAN (R 4.0.2)
#>  P reprex        2.0.0   2021-04-02 [?] CRAN (R 4.0.2)
#>  P rlang         0.4.10  2020-12-30 [?] CRAN (R 4.0.2)
#>  P rmarkdown     2.7     2021-02-19 [?] CRAN (R 4.0.2)
#>    rstudioapi    0.13    2020-11-12 [1] CRAN (R 4.0.2)
#>    sessioninfo   1.1.1   2018-11-05 [3] CRAN (R 4.0.2)
#>    stringi       1.5.3   2020-09-09 [1] CRAN (R 4.0.2)
#>    stringr       1.4.0   2019-02-10 [1] CRAN (R 4.0.2)
#>  P tibble        3.1.1   2021-04-18 [?] CRAN (R 4.0.2)
#>  P tidyr         1.1.3   2021-03-03 [?] CRAN (R 4.0.2)
#>    tidyselect    1.1.0   2020-05-11 [1] CRAN (R 4.0.2)
#>  P utf8          1.2.1   2021-03-12 [?] CRAN (R 4.0.2)
#>  P vctrs         0.3.7   2021-03-29 [?] CRAN (R 4.0.2)
#>  P withr         2.4.2   2021-04-18 [?] CRAN (R 4.0.2)
#>  P workflows   * 0.2.2   2021-03-10 [?] CRAN (R 4.0.2)
#>  P xfun          0.22    2021-03-11 [?] CRAN (R 4.0.2)
#>    xgboost     * 1.3.2.1 2021-01-18 [1] CRAN (R 4.0.2)
#>    yaml          2.2.1   2020-02-01 [1] CRAN (R 4.0.2)
#> 
#> [1] /Users/santiago/code/ds-models-fraud/renv/library/R-4.0/x86_64-apple-darwin17.0
#> [2] /private/var/folders/8d/zxgx1qkx44n7_wp6crx3ycsh0000gn/T/Rtmp6h44Di/renv-system-library
#> [3] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
#> 
#>  P ── Loaded and on-disk path mismatch.

To work around it, I had to change my code to:

mod <- boost_tree("classification") %>% 
  set_engine(
    "xgboost",
    params = list(
      eval_metric = "aucpr",
      objective = "binary:logistic" # <- MUST be present
    )
  )

The objective must be declared explicitly inside params whenever params is used; otherwise object$params$objective is NULL. I'm not sure whether this is expected behavior, i.e. whether dropping the default objective is intentional.
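This thread does not show how parsnip actually combines top-level engine arguments with a user-supplied params list. The sketch below uses two hypothetical helpers to illustrate how a default objective could be lost (`naive_merge()`, where params replaces everything) and one way to keep it (`merge_xgb_params()`, which layers the lists with utils::modifyList()). The defaults and both functions are assumptions for illustration, not parsnip's real internals.

```r
# Hypothetical defaults; not parsnip's real internals.
default_params <- list(objective = "binary:logistic", eta = 0.3)

# Naive merge: a user-supplied `params` list replaces everything else,
# so the default objective disappears.
naive_merge <- function(defaults, engine_args) {
  if (!is.null(engine_args$params)) {
    engine_args$params
  } else {
    utils::modifyList(defaults, engine_args)
  }
}

# Layered merge: defaults < top-level engine args < entries inside `params`.
merge_xgb_params <- function(defaults, engine_args) {
  user_params <- engine_args$params
  engine_args$params <- NULL                      # drop the nested list itself
  merged <- utils::modifyList(defaults, engine_args)
  if (!is.null(user_params)) merged <- utils::modifyList(merged, user_params)
  merged
}

engine_args <- list(params = list(eval_metric = "aucpr"))
naive_merge(default_params, engine_args)$objective       # NULL
merge_xgb_params(default_params, engine_args)$objective  # "binary:logistic"
```

With the layered merge, a top-level objective = "binary:logistic" combined with params = list(eval_metric = "aucpr") would keep both settings, which is the behavior the workaround above has to force by hand.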

jcpsantiago • Apr 20, 2021, 09:04

There are similar reports here as well.

https://github.com/tidymodels/butcher/issues/214

amazongodman • Apr 19, 2022, 23:04

Related to #774.

simonpcouch • Aug 9, 2022, 14:08

This issue has been automatically locked. If you believe you have found a related problem, please file a new issue (with a reprex: https://reprex.tidyverse.org) and link to this issue.

github-actions[bot] • Sep 1, 2022, 01:09