multi_predict() doesn't support `type = "raw"` predictions for `{lightgbm}` classification models
There is some code in {bonsai} that looks like it was intended to support `multi_predict(..., type = "raw")` for {lightgbm} classification models:
https://github.com/tidymodels/bonsai/blob/6c090e16f1a5476da1699ff14d8927f92fbe2c83/R/lightgbm_data.R#L146-L158
However, I don't believe {bonsai} actually respects `type = "raw"` in `multi_predict()`.
Reproducible Example
See the following code for evidence of this claim. I saw this behavior both with {lightgbm} v3.3.2 installed from CRAN and with the latest development version (https://github.com/microsoft/LightGBM/commit/c7102e56b246cc5cd73d9787b2c837c0bc384d1e).
`sessionInfo()`:

```
R version 4.1.0 (2021-05-18)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS 12.2.1

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] modeldata_1.0.0 lightgbm_3.3.2 R6_2.5.1 dplyr_1.0.9 bonsai_0.1.0.9000
[6] parsnip_1.0.0

loaded via a namespace (and not attached):
[1] rstudioapi_0.13 magrittr_2.0.3 tidyselect_1.1.2 munsell_0.5.0 lattice_0.20-45
[6] colorspace_2.0-3 rlang_1.0.4 fansi_1.0.3 tools_4.1.0 hardhat_1.2.0
[11] grid_4.1.0 data.table_1.14.2 gtable_0.3.0 utf8_1.2.2 cli_3.3.0
[16] withr_2.5.0 ellipsis_0.3.2 tibble_3.1.7 lifecycle_1.0.1 crayon_1.5.1
[21] Matrix_1.4-0 purrr_0.3.4 ggplot2_3.3.6 tidyr_1.2.0 vctrs_0.4.1
[26] glue_1.6.2 compiler_4.1.0 pillar_1.8.0 dials_1.0.0 generics_0.1.3
[31] scales_1.2.0 jsonlite_1.8.0 DiceDesign_1.9 pkgconfig_2.0.3
```
```r
library(bonsai)
library(dplyr)
library(lightgbm)
library(modeldata)
library(parsnip)

data("penguins", package = "modeldata")
penguins <- penguins[complete.cases(penguins), ]

# Small subset to predict on, with character/factor columns
# converted to 0-based integer codes
penguins_subset <- penguins[1:10, ]
penguins_subset_numeric <-
  penguins_subset %>%
  mutate(across(where(is.character), ~ as.factor(.x))) %>%
  mutate(across(where(is.factor), ~ as.integer(.x) - 1))

# Three-class {lightgbm} classifier fit through {parsnip}/{bonsai}
clf_multiclass_fit <-
  boost_tree(trees = 5) %>%
  set_engine("lightgbm") %>%
  set_mode("classification") %>%
  fit(species ~ ., data = penguins)

new_data <-
  penguins_subset_numeric %>%
  select(-species) %>%
  as.matrix()

# "raw" predictions from {bonsai} for 1 to 4 trees
preds_bonsai_raw <-
  multi_predict(
    clf_multiclass_fit
    , new_data = new_data[1, , drop = FALSE]
    , trees = seq_len(4)
    , type = "raw"
  )

# Raw scores straight from the underlying {lightgbm} booster, for comparison.
# `booster` and `new_data` are passed by name, so each element of X
# is matched to `num_iteration`.
preds_lgb_raw <-
  t(sapply(
    X = seq_len(4)
    , FUN = function(booster, new_data, num_iteration) {
      booster$predict(new_data, num_iteration = num_iteration, rawscore = TRUE)
    }
    , booster = clf_multiclass_fit$fit
    , new_data = new_data[1, , drop = FALSE]
  ))

# Class probabilities from {bonsai} for 1 to 4 trees
preds_bonsai_prob <-
  multi_predict(
    clf_multiclass_fit
    , new_data = new_data[1, , drop = FALSE]
    , trees = seq_len(4)
    , type = "prob"
  )
```
The predictions from `multi_predict(..., type = "raw")` look like probabilities (between 0 and 1, summing to 1) and don't match {lightgbm}'s output for raw predictions.
```r
preds_bonsai_raw[[".pred"]][[1]]
# A tibble: 4 × 4
# trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
# <int> <dbl> <dbl> <dbl>
# 1 0.500 0.184 0.316
# 2 0.556 0.165 0.279
# 3 0.607 0.147 0.246
# 4 0.652 0.131 0.217

preds_lgb_raw
# [,1] [,2] [,3]
# [1,] -0.6724811 -1.672408 -1.132757
# [2,] -0.5392134 -1.754103 -1.230182
# [3,] -0.4193116 -1.834036 -1.322633
# [4,] -0.3093926 -1.912255 -1.411070
```
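As a quick check on that claim, applying a row-wise softmax to the raw scores returned by {lightgbm} reproduces, to the three decimals printed, the values that `multi_predict(..., type = "raw")` returned. This suggests the "raw" path is handing back class probabilities. The numbers are copied from the two outputs in this report; `softmax_rows()` is a hypothetical helper written only for this comparison.

```r
# Raw scores copied from `preds_lgb_raw` (rows = 1..4 trees, columns = classes)
lgb_raw <- matrix(
  c(-0.6724811, -1.672408, -1.132757,
    -0.5392134, -1.754103, -1.230182,
    -0.4193116, -1.834036, -1.322633,
    -0.3093926, -1.912255, -1.411070),
  ncol = 3, byrow = TRUE
)

# Values printed by multi_predict(..., type = "raw"), copied from the tibble
bonsai_raw <- matrix(
  c(0.500, 0.184, 0.316,
    0.556, 0.165, 0.279,
    0.607, 0.147, 0.246,
    0.652, 0.131, 0.217),
  ncol = 3, byrow = TRUE
)

# Hypothetical helper: row-wise softmax, shifting each row by its max
# for numerical stability (the shift does not change the result)
softmax_rows <- function(x) {
  e <- exp(x - apply(x, 1, max))
  e / rowSums(e)
}

probs <- softmax_rows(lgb_raw)

# Rounded to three decimals, these match the {bonsai} "raw" output
round(probs, 3)
max(abs(probs - bonsai_raw))
```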
Predictions with `type = "prob"` look correct: they are probabilities, as expected.
```r
preds_bonsai_prob[[".pred"]][[1]]
# A tibble: 4 × 4
# trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
# <int> <dbl> <dbl> <dbl>
# 1 1 0.500 0.184 0.316
# 2 2 0.556 0.165 0.279
# 3 3 0.607 0.147 0.246
# 4 4 0.652 0.131 0.217
```
I observed the same behavior for binary classification models. This matters less for regression models, because with {lightgbm}'s default regression objective (L2) the "raw" score already is the prediction.
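What counts as "raw" depends on the objective. The following minimal sketch assumes the built-in `binary` objective with its default `sigmoid = 1.0` and the default L2 `regression` objective. Other objectives, such as `poisson`, use different link functions and are not covered. `raw_to_default()` is a hypothetical helper, not part of {bonsai} or {lightgbm}, that maps raw scores to the scale {lightgbm} reports by default.

```r
# Hypothetical helper: convert raw {lightgbm} scores to the default prediction
# scale, assuming the built-in objectives below with default parameters
raw_to_default <- function(raw, objective = c("binary", "regression")) {
  objective <- match.arg(objective)
  switch(
    objective,
    # binary: raw score is log-odds; the logistic function gives P(positive class)
    binary = stats::plogis(raw),
    # L2 regression: identity link, so the raw score is the prediction
    regression = raw
  )
}

raw_to_default(c(-2, 0, 2), "binary")
raw_to_default(c(1.5, 3.2), "regression")
```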
Notes for Maintainers
I believe the issue is that this block has no `if (type == "raw")` branch:
https://github.com/tidymodels/bonsai/blob/6c090e16f1a5476da1699ff14d8927f92fbe2c83/R/lightgbm.R#L366-L375
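One way a raw branch could work, sketched under assumptions: `raw_scores_by_trees()` is a hypothetical helper, not {bonsai}'s actual internals. It mirrors the comparison code in the reproducible example by calling the booster's `predict()` with `rawscore = TRUE` once per tree count. It assumes a single row of `new_data` and the {lightgbm} 3.x R6 `predict()` method used in that example. `fake_booster` is a made-up stand-in so the sketch runs without a fitted model.

```r
# Hypothetical helper (not {bonsai}'s internal API): raw scores for one row of
# `new_data` at each requested number of trees, one output row per tree count
raw_scores_by_trees <- function(booster, new_data, trees) {
  scores <- lapply(trees, function(n) {
    booster$predict(new_data, num_iteration = n, rawscore = TRUE)
  })
  out <- do.call(rbind, scores)
  cbind(trees = trees, out)
}

# Made-up stand-in for a fitted booster: returns three class scores that
# depend on the tree count, and insists on being asked for raw scores
fake_booster <- list(
  predict = function(data, num_iteration, rawscore) {
    stopifnot(isTRUE(rawscore))
    -c(1, 2, 3) * num_iteration
  }
)

raw_scores_by_trees(fake_booster, new_data = matrix(0, nrow = 1, ncol = 6), trees = 1:3)
```

Passing `rawscore = TRUE` is the key difference from the probability path; reshaping the result into the per-row tibbles that `multi_predict()` returns is left out here.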
Is {bonsai} expected to support `multi_predict(..., type = "raw")` for {lightgbm} classification models? If so, would you be open to me putting up a pull request to add this support?
Thanks for your time and consideration.
Just wanted to drop a note here and let you know this hasn't fallen off my radar!
I'm hoping to spend some time with our `multi_predict` methods and put together more unified machinery for dispatch and testing, and will return to this PR after that. 👍