
multi_predict() doesn't support `type = "raw"` predictions for `{lightgbm}` classification models

Open · jameslamb opened this issue 2 years ago · 1 comment

There is some code in {bonsai} that looks like it was intended to support multi_predict(..., type = "raw") for {lightgbm} classification models.

https://github.com/tidymodels/bonsai/blob/6c090e16f1a5476da1699ff14d8927f92fbe2c83/R/lightgbm_data.R#L146-L158

However, I don't believe {bonsai} actually respects type = "raw" for multi_predict().

Reproducible Example

See the following code for evidence of this claim. I saw this behavior both with {lightgbm} v3.3.2 installed from CRAN and with the latest development version (https://github.com/microsoft/LightGBM/commit/c7102e56b246cc5cd73d9787b2c837c0bc384d1e).

Output of sessionInfo():
R version 4.1.0 (2021-05-18)
Platform: x86_64-apple-darwin17.0 (64-bit)
Running under: macOS 12.2.1

Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.1/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] modeldata_1.0.0   lightgbm_3.3.2    R6_2.5.1          dplyr_1.0.9       bonsai_0.1.0.9000
[6] parsnip_1.0.0    

loaded via a namespace (and not attached):
 [1] rstudioapi_0.13   magrittr_2.0.3    tidyselect_1.1.2  munsell_0.5.0     lattice_0.20-45  
 [6] colorspace_2.0-3  rlang_1.0.4       fansi_1.0.3       tools_4.1.0       hardhat_1.2.0    
[11] grid_4.1.0        data.table_1.14.2 gtable_0.3.0      utf8_1.2.2        cli_3.3.0        
[16] withr_2.5.0       ellipsis_0.3.2    tibble_3.1.7      lifecycle_1.0.1   crayon_1.5.1     
[21] Matrix_1.4-0      purrr_0.3.4       ggplot2_3.3.6     tidyr_1.2.0       vctrs_0.4.1      
[26] glue_1.6.2        compiler_4.1.0    pillar_1.8.0      dials_1.0.0       generics_0.1.3   
[31] scales_1.2.0      jsonlite_1.8.0    DiceDesign_1.9    pkgconfig_2.0.3

library(bonsai)
library(dplyr)
library(lightgbm)
library(modeldata)
library(parsnip)

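data("penguins", package = "modeldata")
penguins <- penguins[complete.cases(penguins), ]

# Encode the factor (and any character) columns of a 10-row subset as
# zero-based integer codes, so it can be passed to {lightgbm} as a numeric matrix
penguins_subset <- penguins[1:10, ]
penguins_subset_numeric <-
    penguins_subset %>%
    mutate(across(where(is.character), ~as.factor(.x))) %>%
    mutate(across(where(is.factor), ~as.integer(.x) - 1))

# Fit a 5-iteration multiclass {lightgbm} model through {parsnip}/{bonsai}
clf_multiclass_fit <-
    boost_tree(trees = 5) %>%
    set_engine("lightgbm") %>%
    set_mode("classification") %>%
    fit(species ~ ., data = penguins)

new_data <-
    penguins_subset_numeric %>%
    select(-species) %>%
    as.matrix()

# {bonsai}: "raw" predictions for the first row, using 1 to 4 trees
preds_bonsai_raw <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "raw"
    )

# {lightgbm}: raw scores for the same row, straight from the lgb.Booster.
# sapply() passes each element of X as `num_iteration` and returns one
# column per iteration count; t() turns that into one row per count.
preds_lgb_raw <-
    t(sapply(
        X = seq_len(4)
        , FUN = function(booster, new_data, num_iteration) {
            booster$predict(new_data, num_iteration = num_iteration, rawscore = TRUE)
        }
        , booster = clf_multiclass_fit$fit
        , new_data = new_data[1, , drop = FALSE]
    ))

# {bonsai}: "prob" predictions for the same row, for comparison
preds_bonsai_prob <-
    multi_predict(
        clf_multiclass_fit
        , new_data = new_data[1, , drop = FALSE]
        , trees = seq_len(4)
        , type = "prob"
    )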

The predictions from multi_predict(..., type = "raw") look like probabilities (every value is between 0 and 1, and each row sums to 1), and they don't match {lightgbm}'s raw predictions.

preds_bonsai_raw[[".pred"]][[1]]
# A tibble: 4 × 4
#   trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#   <int>        <dbl>           <dbl>        <dbl>
# 1     1        0.500           0.184        0.316
# 2     2        0.556           0.165        0.279
# 3     3        0.607           0.147        0.246
# 4     4        0.652           0.131        0.217

preds_lgb_raw
#            [,1]      [,2]      [,3]
# [1,] -0.6724811 -1.672408 -1.132757
# [2,] -0.5392134 -1.754103 -1.230182
# [3,] -0.4193116 -1.834036 -1.322633
# [4,] -0.3093926 -1.912255 -1.411070

The type = "prob" predictions look correct: they are probabilities, and they are identical to the type = "raw" output above.

preds_bonsai_prob[[".pred"]][[1]]
# A tibble: 4 × 4
#   trees .pred_Adelie .pred_Chinstrap .pred_Gentoo
#   <int>        <dbl>           <dbl>        <dbl>
# 1     1        0.500           0.184        0.316
# 2     2        0.556           0.165        0.279
# 3     3        0.607           0.147        0.246
# 4     4        0.652           0.131        0.217
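One way to check that the type = "raw" output is really transformed scores: assuming the model uses {lightgbm}'s multiclass (softmax) objective, a row-wise softmax of the raw scores from booster$predict(..., rawscore = TRUE) should reproduce the class probabilities. This sketch hard-codes the values printed above; softmax_rows() is a hypothetical helper written for this check, not part of {bonsai} or {lightgbm}.

```r
# Raw scores printed above by booster$predict(..., rawscore = TRUE):
# one row per number of iterations (1 to 4), one column per class
# (Adelie, Chinstrap, Gentoo)
lgb_raw <- matrix(
    c(-0.6724811, -1.672408, -1.132757,
      -0.5392134, -1.754103, -1.230182,
      -0.4193116, -1.834036, -1.322633,
      -0.3093926, -1.912255, -1.411070),
    nrow = 4, byrow = TRUE
)

# Values printed above for multi_predict(..., type = "raw"),
# rounded for display by the tibble print method
bonsai_raw <- matrix(
    c(0.500, 0.184, 0.316,
      0.556, 0.165, 0.279,
      0.607, 0.147, 0.246,
      0.652, 0.131, 0.217),
    nrow = 4, byrow = TRUE
)

# Hypothetical helper: row-wise softmax. Subtracting each row's maximum
# first avoids overflow in exp() without changing the result.
softmax_rows <- function(x) {
    z <- exp(x - apply(x, 1, max))
    z / rowSums(z)
}

# Largest absolute gap between softmax(raw scores) and bonsai's "raw" output
max_abs_diff <- max(abs(softmax_rows(lgb_raw) - bonsai_raw))
print(max_abs_diff)
```

Every gap is below 0.001, which is just the display rounding, so the values returned for type = "raw" appear to be softmax-transformed scores rather than raw scores.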

I observed the same thing for binary classification models. This doesn't matter for regression models, because "raw" predictions are the default for {lightgbm} regression models using built-in objectives.
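For binary models, the equivalent check uses the logistic function instead of softmax. Assuming {lightgbm}'s binary objective with its default sigmoid = 1.0, the probability is plogis() of the raw score, and qlogis() maps a probability back to the raw score (log-odds). The raw values below are illustrative only, not output from the models in this issue: if {bonsai}'s type = "raw" output equals plogis() of booster$predict(..., rawscore = TRUE), the binary case has the same problem.

```r
# Illustrative raw scores (log-odds); made-up values, not output from
# the models in this issue
raw_binary <- c(-2, 0, 1.5)

# Assuming the binary objective with the default sigmoid = 1.0,
# probability = 1 / (1 + exp(-raw)), which is what stats::plogis() computes
prob_binary <- plogis(raw_binary)

# qlogis() is the inverse (log-odds), so it maps probabilities back to raw scores
raw_recovered <- qlogis(prob_binary)

print(prob_binary)
```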

Notes for Maintainers

I believe the issue is that this block does not contain an if (type == "raw") condition:

https://github.com/tidymodels/bonsai/blob/6c090e16f1a5476da1699ff14d8927f92fbe2c83/R/lightgbm.R#L366-L375
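As a rough sketch of what a type == "raw" branch would need to produce, the hypothetical helper below (multi_predict_raw_sketch(), not {bonsai}'s actual code) collects raw scores for each requested number of trees and arranges them like parsnip's multi_predict() output: one .pred entry per row of new_data, each holding a trees column plus one .pred_<class> column. The raw_fn argument is an assumption standing in for a call like booster$predict(new_data, num_iteration = n, rawscore = TRUE), assumed here to return an nrow(new_data) by n_classes matrix; real code would also return a tibble rather than a base data.frame.

```r
# Hypothetical sketch, not bonsai's implementation.
# raw_fn(n): assumed to return an n_obs x n_classes matrix of raw scores
# computed with the first n boosting iterations.
multi_predict_raw_sketch <- function(raw_fn, trees, class_names) {
    # One raw-score matrix per requested number of trees
    per_tree <- lapply(trees, raw_fn)
    n_obs <- nrow(per_tree[[1]])

    # For each observation, stack its raw scores across the `trees` values
    pred <- lapply(seq_len(n_obs), function(i) {
        scores <- do.call(rbind, lapply(per_tree, function(m) m[i, , drop = FALSE]))
        colnames(scores) <- paste0(".pred_", class_names)
        data.frame(trees = trees, scores, check.names = FALSE)
    })

    # Mimic parsnip's layout: a `.pred` list column with one entry per row
    out <- data.frame(.row = seq_len(n_obs))
    out$.pred <- pred
    out[".pred"]
}

# Usage with a stub in place of a real booster: 2 observations, 2 classes
stub_raw_fn <- function(n) matrix(c(-n, 0, n, 2 * n), nrow = 2, byrow = TRUE)
res <- multi_predict_raw_sketch(stub_raw_fn, trees = 1:3, class_names = c("a", "b"))
print(res$.pred[[2]])
```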

Is it expected that {bonsai} supports multi_predict(..., type = "raw") for {lightgbm} classification models? If so, would you be open to me putting up a pull request to add this support?

Thanks for your time and consideration.

jameslamb, Aug 07 '22 04:08

Just wanted to drop a note here and let you know this hasn't fallen off my radar!

I'm hoping to spend some time with our multi_predict methods and put together some more unified machinery for dispatch and testing, and will return to this PR after that. 👍

simonpcouch, Aug 17 '22 16:08