Draft PR to allow dataframe as `validation` arg in `xgboost`
This is a draft that allows the user to pass a data frame as the `validation` argument for `xgboost`. From what I can tell, the main issue that prevents passing a data frame with extra columns is how `y` is passed to `as_xgb_data()`:
https://github.com/tidymodels/parsnip/blob/f8505bdf6f039076c1eca5aa37a103a6e072887d/R/boost_tree.R#L380
The `y` vector is unnamed, so you can't just select the columns from `x` and identify which column corresponds to `y` to pass to `xgb.DMatrix()`.
Also, case weights weren't being passed to the internal validation set in the current implementation.
I would be happy to develop this PR further with some advice from your team.
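To illustrate the case-weights point, here is a minimal base R sketch of holding out a proportion of rows while keeping the weights aligned with both partitions. `split_with_weights()` is a hypothetical helper written for this description, not parsnip's internal code, and the example weights are made up.

```r
# Hypothetical helper (not parsnip internals): hold out a proportion of rows
# and subset the case weights with the same indices as x and y.
split_with_weights <- function(x, y, weights = NULL, prop = 0.2) {
  n <- nrow(x)
  n_val <- max(1L, floor(n * prop))
  val_idx <- sample.int(n, n_val)

  list(
    train = list(
      x = x[-val_idx, , drop = FALSE],
      y = y[-val_idx],
      weights = if (is.null(weights)) NULL else weights[-val_idx]
    ),
    validation = list(
      x = x[val_idx, , drop = FALSE],
      y = y[val_idx],
      weights = if (is.null(weights)) NULL else weights[val_idx]
    )
  )
}

set.seed(1)
x <- as.matrix(mtcars[, -1])
# Illustrative weights only: alternate 1 and 2 across the 32 rows
parts <- split_with_weights(x, mtcars$mpg, weights = rep(c(1, 2), 16), prop = 0.2)
```

Each partition's weights could then be supplied when the corresponding `xgb.DMatrix()` is built, so the validation metric is weighted the same way as training.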
Examples
> # Works
> reg_fit <-
+ boost_tree(trees = 10, mode = "regression") %>%
+ set_engine("xgboost",
+ eval_metric = "mae",
+ validation = mtcars[1:3,],
+ verbose = 1) %>%
+ fit(mpg ~ ., data = mtcars)
[1] validation-mae:14.428947
[2] validation-mae:10.406661
[3] validation-mae:7.885722
[4] validation-mae:5.953925
[5] validation-mae:4.558322
[6] validation-mae:3.400241
[7] validation-mae:2.560631
[8] validation-mae:1.757373
[9] validation-mae:1.187967
[10] validation-mae:0.787739
>
> reg_fit <-
+ boost_tree(trees = 10, mode = "regression") %>%
+ set_engine("xgboost",
+ eval_metric = "mae",
+ validation = 0.2,
+ verbose = 1) %>%
+ fit(mpg ~ ., data = mtcars)
[1] validation-mae:14.203231
[2] validation-mae:9.799767
[3] validation-mae:7.313254
[4] validation-mae:5.302311
[5] validation-mae:3.932890
[6] validation-mae:2.969718
[7] validation-mae:2.525703
[8] validation-mae:2.261265
[9] validation-mae:2.169489
[10] validation-mae:2.301616
>
>
> # Errors
> car_rec <-
+ recipe(mpg ~ disp + hp + cyl, data = mtcars) |>
+ update_role(cyl, new_role = 'id')
>
>
> reg_fit <-
+ boost_tree(trees = 10, mode = "regression") %>%
+ set_engine("xgboost",
+ eval_metric = "mae",
+ validation = mtcars[,c(1,2,3,4)],
+ verbose = 1)
>
> result <-
+ workflows::workflow(car_rec, reg_fit) |>
+ fit(data = mtcars)
Error in `parsnip::xgb_train()`:
! `validation` should contain 3 columns
Run `rlang::last_error()` to see where the error occurred.
>
> mtcars_random <-
+ mtcars |>
+ mutate(random = runif(nrow(mtcars), 0, 10))
>
> reg_fit <-
+ boost_tree(trees = 10, mode = "regression") %>%
+ set_engine("xgboost",
+ eval_metric = "mae",
+ validation = mtcars_random,
+ verbose = 1) %>%
+ fit(mpg ~ ., data = mtcars)
Error in `parsnip::xgb_train()`:
! `validation` should contain 11 columns
Run `rlang::last_error()` to see where the error occurred.
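
Both failures come from `validation` having columns the model does not use. One possible direction, sketched below in base R, is to align the validation data frame to the training predictors by name. `align_validation()` is a hypothetical helper, and it assumes the outcome name is available at that point, which is not the case today because `y` arrives unnamed. It also assumes the predictors are used as-is, without preprocessing such as dummy encoding.

```r
# Hypothetical helper: keep only the training predictors (by name) and
# extract the outcome by name, ignoring any extra columns.
align_validation <- function(validation, predictor_names, outcome_name) {
  needed <- c(outcome_name, predictor_names)
  missing_cols <- setdiff(needed, names(validation))
  if (length(missing_cols) > 0) {
    stop("`validation` is missing columns: ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }
  list(
    x = as.matrix(validation[, predictor_names, drop = FALSE]),
    y = validation[[outcome_name]]
  )
}

# A validation set with one extra column, similar to mtcars_random above
mtcars_extra <- transform(mtcars, random = seq_len(nrow(mtcars)))
val <- align_validation(mtcars_extra, setdiff(names(mtcars), "mpg"), "mpg")
dim(val$x)
```

With this approach the extra `random` column would be dropped rather than triggering the column-count error, while a genuinely missing predictor would still fail with a message naming it.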