
pass weights to xgboost internal validation set

Open joeycouse opened this issue 2 years ago • 2 comments

In response to https://github.com/tidymodels/parsnip/pull/771#issuecomment-1231991055

This pull request passes case weights to the internal validation set of xgboost. This causes test failures here:

https://github.com/tidymodels/parsnip/blob/e1eb30a6f6704bd0a3cf61736ea1d26e1d8eb081/tests/testthat/test_boost_tree_xgboost.R#L409

and here:

https://github.com/tidymodels/parsnip/blob/e1eb30a6f6704bd0a3cf61736ea1d26e1d8eb081/tests/testthat/test_boost_tree_xgboost.R#L449

Seems like the original intention was to not pass case weights to the internal validation set?

joeycouse, Aug 30 '22 19:08

This depends on the type of weight:

  • Importance weights only affect the model estimation and supervised recipes steps. They are not used with yardstick functions for calculating measures of model performance.

  • Frequency weights are used for all parts of the preprocessing, model fitting, and performance estimation operations.

(This is from the blog post, but it should be better documented in hardhat.)

So we should do this, but only for frequency weights (which are less likely to be used with boosting).
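As a rough illustration of that rule, here is a minimal sketch. It assumes a hypothetical helper named validation_weights(), which is not part of parsnip and is not the code in the PR. It only uses hardhat's exported is_frequency_weights() predicate to decide whether weights should accompany the validation set.

```r
library(hardhat)

# Hypothetical helper, not parsnip's actual implementation: decide which
# weights (if any) should accompany the internal validation set.
validation_weights <- function(weights) {
  # No case weights supplied: nothing to pass along.
  if (is.null(weights)) {
    return(NULL)
  }
  # Frequency weights also apply to performance estimation, so keep them
  # (converted to plain integers).
  if (is_frequency_weights(weights)) {
    return(as.integer(weights))
  }
  # Importance weights only affect model estimation, so drop them here.
  NULL
}

freq_result <- validation_weights(frequency_weights(1:3))
imp_result <- validation_weights(importance_weights(c(0.5, 1, 2)))
```

Under these assumptions, freq_result holds the integer weights and imp_result is NULL, so performance on the validation set is only weighted when the weights represent frequencies.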

topepo, Aug 31 '22 12:08

@topepo @simonpcouch I've updated the PR to pass only frequency weights to the internal validation set. This is done by delaying the conversion of the weights to numeric/integer until they are passed to as_xgb_data().

library(parsnip)

freq_weights <- hardhat::frequency_weights(1:32)

mtcar_x <- mtcars[, -1]
mtcar_mat <- as.matrix(mtcar_x)

set.seed(1)
val_freq_wts <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, weights = freq_weights, validation = 1/10)
xgboost::getinfo(val_freq_wts$watchlist$validation, "weight")
#> [1]  3 17 26


# Importance weights: not passed to the validation set
imp_wts <- hardhat::importance_weights(1:32)

mtcar_x <- mtcars[, -1]
mtcar_mat <- as.matrix(mtcar_x)

set.seed(1)
val_imp_wts <- parsnip:::as_xgb_data(mtcar_mat, mtcars$mpg, weights = imp_wts, validation = 1/10)
xgboost::getinfo(val_imp_wts$watchlist$validation, "weight")
#> NULL

Created on 2022-09-01 by the reprex package (v2.0.1)

joeycouse, Sep 01 '22 16:09