parsnip
Account for possibility of custom objective function in XGBoost `boost_tree()`
Currently, passing a custom objective function causes an error downstream when predicting. This happens in `xgb_pred()` when using `switch()` off of the objective (usually a character string) to modify the output of `predict.xgb.Booster()`.
library(xgboost)
library(parsnip)
library(workflows)
mod <- boost_tree("regression") %>%
  set_engine("xgboost",
    objective = function(preds, dtrain) {
      truth <- as.numeric(getinfo(dtrain, "label"))
      error <- truth - preds
      gradient <- -2 * error
      hess <- rep.int(2, length(preds))
      list(grad = gradient, hess = hess)
    }
  )
dt <- data.frame(x = rnorm(15))
dt$y <- dt$x + rnorm(15, 0, .05)
wf <- workflow() %>%
  add_model(mod) %>%
  add_formula(y ~ x)
fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
I also see this error when using `parsnip::set_engine("xgboost", params = list(eval_metric = "aucpr"))` without setting the `objective` argument. I came across this error after updating parsnip from 0.1.4 to 0.1.5, at which point `tune::tune_grid()` started failing. (tidymodels and the other individual packages, namely {workflows} and {tune}, were also updated in that time.)
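Both failure modes described above (a function as the objective, and no objective at all) can be reproduced in base R without xgboost. The sketch below is only an illustration: `link_for()` is a hypothetical stand-in for the objective-dependent branch, not parsnip's actual `xgb_pred()` code. It shows that `switch()` rejects any `EXPR` that is not a length-1 vector, which covers both a closure and `NULL`.

```r
# Hypothetical stand-in for a branch that dispatches on the objective.
# This is NOT parsnip's implementation; it only mimics the switch() call shape.
link_for <- function(objective) {
  switch(objective,
    "binary:logitraw" = "inverse logit",
    "identity" # unnamed last argument acts as the default
  )
}

# A custom objective is a function, so it is not a length-1 vector.
custom_obj <- function(preds, dtrain) NULL

# Record whether each call succeeds or raises an error.
res_string <- tryCatch(link_for("reg:squarederror"), error = function(e) "error")
res_fun <- tryCatch(link_for(custom_obj), error = function(e) "error")
res_null <- tryCatch(link_for(NULL), error = function(e) "error")
```

Under these assumptions, only the string objective reaches the default branch; the function and `NULL` cases both raise an error before any branch is chosen.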
This test is passing: https://github.com/tidymodels/parsnip/blob/cb086385a90227eacfce2f06ed58ff2d4e17bb29/tests/testthat/test_boost_tree_xgboost.R#L169
spec <-
  boost_tree() %>%
  set_engine("xgboost", objective = "reg:pseudohubererror") %>%
  set_mode("regression")
xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
But if `objective` is not a string (I guess this is the reason for labelling this as a feature request instead of a bug), it fails like in the OP's code. Additionally, if one adds anything else to `set_engine()`, it fails with the same error. Are the `...` all added to the same vector?
library(xgboost)
library(parsnip)
library(workflows)
mod <- boost_tree("classification") %>%
  set_engine(
    "xgboost",
    objective = "binary:logistic",
    params = list(eval_metric = "aucpr") # <- added this and changed the data to be a classification problem
  )
dt <- data.frame(
  x = rnorm(15),
  y = rnorm(15) + rnorm(15, 0, .05),
  target = as.factor(rbinom(15, 1, 0.5))
)
wf <- workflow() %>%
  add_model(mod) %>%
  add_formula(target ~ x + y)
fitted <- fit(wf, data = dt)
predict(fitted, new_data = dt)
#> Error in switch(object$params$objective, `binary:logitraw` = stats::binomial()$linkinv(res), : EXPR must be a length 1 vector
Created on 2021-04-20 by the reprex package (v2.0.0)
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.0.2 (2020-06-22)
#> os macOS 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Berlin
#> date 2021-04-20
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> ! package * version date lib source
#> assertthat 0.2.1 2019-03-21 [1] CRAN (R 4.0.2)
#> P cli 2.4.0 2021-04-05 [?] CRAN (R 4.0.2)
#> P codetools 0.2-18 2020-11-04 [3] CRAN (R 4.0.2)
#> P crayon 1.4.1 2021-02-08 [?] CRAN (R 4.0.2)
#> P data.table 1.14.0 2021-02-21 [?] CRAN (R 4.0.2)
#> P DBI 1.1.1 2021-01-15 [?] CRAN (R 4.0.2)
#> digest 0.6.27 2020-10-24 [1] CRAN (R 4.0.2)
#> P dplyr 1.0.5 2021-03-05 [?] CRAN (R 4.0.2)
#> ellipsis 0.3.1 2020-05-15 [1] CRAN (R 4.0.2)
#> P evaluate 0.14 2019-05-28 [?] CRAN (R 4.0.0)
#> P fansi 0.4.2 2021-01-15 [?] CRAN (R 4.0.2)
#> P fs 1.5.0 2020-07-31 [?] CRAN (R 4.0.2)
#> P generics 0.1.0 2020-10-31 [?] CRAN (R 4.0.2)
#> globals 0.14.0 2020-11-22 [1] CRAN (R 4.0.2)
#> glue 1.4.2 2020-08-27 [1] CRAN (R 4.0.2)
#> P hardhat 0.1.5 2020-11-09 [?] CRAN (R 4.0.2)
#> P highr 0.9 2021-04-16 [?] CRAN (R 4.0.2)
#> P htmltools 0.5.1.1 2021-01-22 [?] CRAN (R 4.0.2)
#> P knitr 1.32 2021-04-14 [?] CRAN (R 4.0.2)
#> P lattice 0.20-41 2020-04-02 [3] CRAN (R 4.0.2)
#> P lifecycle 1.0.0 2021-02-15 [?] CRAN (R 4.0.2)
#> magrittr 2.0.1 2020-11-17 [1] CRAN (R 4.0.2)
#> P Matrix 1.3-2 2021-01-06 [?] CRAN (R 4.0.2)
#> P parsnip * 0.1.5 2021-01-19 [?] CRAN (R 4.0.2)
#> P pillar 1.6.0 2021-04-13 [?] CRAN (R 4.0.2)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.0.2)
#> purrr 0.3.4 2020-04-17 [1] CRAN (R 4.0.2)
#> R6 2.5.0 2020-10-28 [1] CRAN (R 4.0.2)
#> P reprex 2.0.0 2021-04-02 [?] CRAN (R 4.0.2)
#> P rlang 0.4.10 2020-12-30 [?] CRAN (R 4.0.2)
#> P rmarkdown 2.7 2021-02-19 [?] CRAN (R 4.0.2)
#> rstudioapi 0.13 2020-11-12 [1] CRAN (R 4.0.2)
#> sessioninfo 1.1.1 2018-11-05 [3] CRAN (R 4.0.2)
#> stringi 1.5.3 2020-09-09 [1] CRAN (R 4.0.2)
#> stringr 1.4.0 2019-02-10 [1] CRAN (R 4.0.2)
#> P tibble 3.1.1 2021-04-18 [?] CRAN (R 4.0.2)
#> P tidyr 1.1.3 2021-03-03 [?] CRAN (R 4.0.2)
#> tidyselect 1.1.0 2020-05-11 [1] CRAN (R 4.0.2)
#> P utf8 1.2.1 2021-03-12 [?] CRAN (R 4.0.2)
#> P vctrs 0.3.7 2021-03-29 [?] CRAN (R 4.0.2)
#> P withr 2.4.2 2021-04-18 [?] CRAN (R 4.0.2)
#> P workflows * 0.2.2 2021-03-10 [?] CRAN (R 4.0.2)
#> P xfun 0.22 2021-03-11 [?] CRAN (R 4.0.2)
#> xgboost * 1.3.2.1 2021-01-18 [1] CRAN (R 4.0.2)
#> yaml 2.2.1 2020-02-01 [1] CRAN (R 4.0.2)
#>
#> [1] /Users/santiago/code/ds-models-fraud/renv/library/R-4.0/x86_64-apple-darwin17.0
#> [2] /private/var/folders/8d/zxgx1qkx44n7_wp6crx3ycsh0000gn/T/Rtmp6h44Di/renv-system-library
#> [3] /Library/Frameworks/R.framework/Versions/4.0/Resources/library
#>
#> P ── Loaded and on-disk path mismatch.
To fix it I had to change my code to:
mod <- boost_tree("classification") %>%
  set_engine(
    "xgboost",
    params = list(
      eval_metric = "aucpr",
      objective = "binary:logistic" # <- MUST be present
    )
  )
The objective must be explicitly declared inside `params` if `params` is used; otherwise `object$params$objective` is `NULL`. Not sure if this is expected behavior, i.e. whether the default was dropped.
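One way to make the prediction step tolerant of these cases is to normalize the objective to a single string before dispatching on it. The following is a minimal sketch, not the fix adopted in parsnip: `describe_objective()` and `apply_link()` are hypothetical helpers, and the only branch taken from the error message above is the `binary:logitraw` one. A custom function or a missing objective simply falls through to returning the raw predictions, which is an assumption about the desired behavior.

```r
# Hypothetical helper: reduce whatever is stored in params$objective
# to a single string that is always safe to pass to switch().
describe_objective <- function(params) {
  obj <- params$objective
  if (is.null(obj)) {
    "missing" # e.g. params = list(eval_metric = "aucpr") with no objective
  } else if (is.function(obj)) {
    "custom" # user-supplied objective function
  } else if (is.character(obj) && length(obj) == 1) {
    obj
  } else {
    stop("`objective` must be a string or a function.", call. = FALSE)
  }
}

# Hypothetical post-processing step mirroring the shape of the failing call.
apply_link <- function(res, params) {
  switch(describe_objective(params),
    "binary:logitraw" = stats::binomial()$linkinv(res),
    res # default: leave predictions unchanged
  )
}

raw <- c(-1, 0, 1)
out_custom <- apply_link(raw, list(objective = function(preds, dtrain) NULL))
out_missing <- apply_link(raw, list(eval_metric = "aucpr"))
out_logit <- apply_link(0, list(objective = "binary:logitraw"))
```

With this guard, the two `params` lists that previously triggered the `switch()` error return the raw predictions, and only the `binary:logitraw` string applies the inverse logit.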
There are similar reports here as well.
https://github.com/tidymodels/butcher/issues/214
Related to #774.