Draft PR to allow dataframe as `validation` arg in `xgboost`

Open joeycouse opened this issue 2 years ago • 0 comments

This is a draft that lets the user pass a data frame as the `validation` arg for xgboost. From what I can tell, the main thing preventing a data frame with extra columns from being passed is how y is handed to as_xgb_data(): https://github.com/tidymodels/parsnip/blob/f8505bdf6f039076c1eca5aa37a103a6e072887d/R/boost_tree.R#L380

The y vector is unnamed at that point, so you can't simply use the column names of x, plus the name of the column that corresponds to y, to select the right columns from `validation` before passing them to xgb.DMatrix().

Also, case weights weren't being passed to the internal validation set in the current implementation.
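
To make the column-matching idea concrete, here is a rough sketch of selecting columns by name. It is not the code in this PR: `split_validation()` and its arguments are hypothetical names used only for illustration, and it assumes the outcome name is available, which is exactly what is missing today. The only xgboost call used is `xgb.DMatrix()`, and it is guarded so the sketch runs without the package installed.

```r
# Hypothetical helper (not part of parsnip): split a validation data frame
# into the predictor matrix and outcome vector, selecting columns by name.
split_validation <- function(validation, predictor_names, outcome_name,
                             weights = NULL) {
  needed <- c(predictor_names, outcome_name)
  missing_cols <- setdiff(needed, names(validation))
  if (length(missing_cols) > 0) {
    stop("`validation` is missing columns: ",
         paste(missing_cols, collapse = ", "))
  }
  if (!is.null(weights) && length(weights) != nrow(validation)) {
    stop("`weights` must have one value per validation row")
  }
  # Selecting by name drops extra columns and keeps the training column order.
  x <- as.matrix(validation[, predictor_names, drop = FALSE])
  y <- validation[[outcome_name]]
  list(x = x, y = y, weights = weights)
}

# A validation set with one extra column, as in the failing example.
mtcars_random <- cbind(mtcars, random = runif(nrow(mtcars), 0, 10))

parts <- split_validation(
  mtcars_random,
  predictor_names = setdiff(names(mtcars), "mpg"),
  outcome_name = "mpg"
)

# If xgboost is installed, the pieces can be handed to xgb.DMatrix().
if (requireNamespace("xgboost", quietly = TRUE)) {
  val_dmat <- xgboost::xgb.DMatrix(data = parts$x, label = parts$y)
}
```

With something along these lines, the extra `random` column would be dropped instead of triggering the column-count error, and the validation case weights would have an obvious place to be carried along.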

I'd be happy to develop this PR further with some advice from your team.

Examples

> # Works 
> reg_fit <-
+   boost_tree(trees = 10, mode = "regression") %>%
+   set_engine("xgboost", 
+              eval_metric = "mae", 
+              validation = mtcars[1:3,],
+              verbose = 1) %>%
+   fit(mpg ~ ., data = mtcars)
[1]	validation-mae:14.428947 
[2]	validation-mae:10.406661 
[3]	validation-mae:7.885722 
[4]	validation-mae:5.953925 
[5]	validation-mae:4.558322 
[6]	validation-mae:3.400241 
[7]	validation-mae:2.560631 
[8]	validation-mae:1.757373 
[9]	validation-mae:1.187967 
[10]	validation-mae:0.787739 
> 
> reg_fit <-
+   boost_tree(trees = 10, mode = "regression") %>%
+   set_engine("xgboost", 
+              eval_metric = "mae", 
+              validation = 0.2,
+              verbose = 1) %>%
+   fit(mpg ~ ., data = mtcars)
[1]	validation-mae:14.203231 
[2]	validation-mae:9.799767 
[3]	validation-mae:7.313254 
[4]	validation-mae:5.302311 
[5]	validation-mae:3.932890 
[6]	validation-mae:2.969718 
[7]	validation-mae:2.525703 
[8]	validation-mae:2.261265 
[9]	validation-mae:2.169489 
[10]	validation-mae:2.301616 
> 
> 
> # Errors
> car_rec <- 
+   recipe(mpg ~ disp + hp + cyl, data = mtcars) |> 
+   update_role(cyl, new_role = 'id')
> 
> 
> reg_fit <-
+     boost_tree(trees = 10, mode = "regression") %>%
+     set_engine("xgboost", 
+                eval_metric = "mae", 
+                validation = mtcars[,c(1,2,3,4)],
+                verbose = 1) 
> 
> result <- 
+   workflows::workflow(car_rec, reg_fit) |> 
+   fit(data = mtcars)
Error in `parsnip::xgb_train()`:
! `validation` should contain 3 columns
Run `rlang::last_error()` to see where the error occurred.
> 
> mtcars_random <-
+   mtcars |> 
+   mutate(random = runif(nrow(mtcars), 0, 10))
> 
> reg_fit <-
+   boost_tree(trees = 10, mode = "regression") %>%
+   set_engine("xgboost", 
+              eval_metric = "mae", 
+              validation = mtcars_random,
+              verbose = 1) %>%
+   fit(mpg ~ ., data = mtcars)
Error in `parsnip::xgb_train()`:
! `validation` should contain 11 columns
Run `rlang::last_error()` to see where the error occurred.

joeycouse, Jul 20 '22 18:07