stacks: reduce the size of the final model_stack used for making predictions
After finalizing a model_stack with fit_members(), I get a large object (~13.5 GB) that is inconvenient to use in my research production setting. Is there a way to trim the final stack object while keeping its wonderfully simple predict() interface?
My initial thought was to simply remove the "data_stack" and "splits" elements (~6.5 GB and ~5.2 GB, respectively, of the 13.5 GB total). As far as I can tell, the splits aren't required by predict.model_stack(), but the data_stack is used to retrieve the outcome when type = "class".
If I'm only interested in type = "prob", I can remove the data_stack and splits elements by indexing and still pass the trimmed object to predict.model_stack() without consequence (I think).
But maybe there is a more elegant solution.
I've attached a reprex below.
# HOUSEKEEPING ####
# No `rm(list = ls())` here: clearing the global environment does not detach
# packages or reset options, so it is not a clean slate. Run the script in a
# fresh R session instead.
# CRAN libraries
library(tidyverse) # install.packages("tidyverse")
library(tidymodels) # install.packages("tidymodels")
library(stacks) # install.packages("stacks")
# DEFINE SEED ####
seed <- 42
# DEFINE SET ####
# Drop `clutch` and `latency`; the remaining columns feed the classification
# recipe for `hatched`.
tree_frogs_class_train <- dplyr::select(tree_frogs, -c(clutch, latency))
# SET UP RESAMPLING ####
# Seed immediately before assigning folds so the 5 splits are reproducible.
set.seed(seed)
cv_folds <- rsample::vfold_cv(tree_frogs_class_train, v = 5)
# BUILD MODELS ####
# Two classification specifications whose hyperparameters are left as tune()
# placeholders for the grid search: a linear SVM (kernlab) and an elastic-net
# logistic regression (glmnet).
mod_svmlinear <- parsnip::svm_linear(cost = tune(), margin = tune()) %>%
  parsnip::set_engine(engine = "kernlab") %>%
  parsnip::set_mode(mode = "classification")
mod_elasticnet <- parsnip::logistic_reg(penalty = tune(), mixture = tune()) %>%
  parsnip::set_engine(engine = "glmnet") %>%
  parsnip::set_mode(mode = "classification")
# BUILD RECIPE ####
# Role-aware selectors restrict dummy encoding and normalization to the
# predictors, so the `hatched` outcome no longer has to be excluded by hand.
recipe_set <- recipes::recipe(hatched ~ ., data = tree_frogs_class_train) %>%
  recipes::step_dummy(recipes::all_nominal_predictors()) %>%
  recipes::step_zv(recipes::all_predictors()) %>%
  recipes::step_normalize(recipes::all_numeric_predictors())
# CREATE WORKFLOW ####
# Combine the single recipe with each named model specification.
workflow_set <- workflowsets::workflow_set(
  models = list(svm_linear = mod_svmlinear, elasticnet = mod_elasticnet),
  preproc = list(recipe_set)
)
# TUNING ####
# Tune every workflow over 25 grid candidates, scored by ROC AUC. Resampled
# predictions and workflows are saved so stacks::add_candidates() can use them.
res_tune <- workflow_set %>%
  workflowsets::workflow_map(
    seed = seed,
    resamples = cv_folds,
    fn = "tune_grid",
    grid = 25,
    metrics = yardstick::metric_set(roc_auc),
    control = tune::control_grid(
      save_pred = TRUE,
      save_workflow = TRUE,
      verbose = FALSE,
      allow_par = FALSE
    )
  )
# DEFINE ENSEMBLE STACK ####
# Register the tuning results as candidates, estimate the blend, then fit the
# member models.
stack_finalized <- stacks::fit_members(
  stacks::blend_predictions(
    stacks::add_candidates(stacks::stacks(), res_tune)
  )
)
# CHECK THE ELEMENTS OF THE MODEL_STACK ####
stack_finalized %>% names()
#> [1] "model_defs" "coefs" "penalty" "metrics"
#> [5] "equations" "cols_map" "model_metrics" "train"
#> [9] "mode" "outcome" "data_stack" "splits"
#> [13] "member_fits"
# CHECK THE SIZE OF THE WHOLE MODEL_STACK ####
stack_finalized %>%
  object.size() %>%
  format("MB")
#> [1] "2 Mb"
# TRIM THE MODEL_STACK ####
# Drop `data_stack` and `splits` by name rather than by position
# (`-c(11, 12)`), so the trim still targets the right elements if their
# order ever changes. Only suitable for type = "prob": data_stack is used to
# retrieve the outcome for type = "class". Single-bracket subsetting also
# drops the class attribute, so the result is a plain list.
# stacks::butcher() is the supported way to shrink a fitted stack.
stack_finalized_trimmed <-
  stack_finalized[!names(stack_finalized) %in% c("data_stack", "splits")]
# CHECK THE SIZE OF THE MODEL_STACK WITHOUT data_stack AND splits ####
stack_finalized_trimmed %>%
  object.size() %>%
  format("MB")
#> [1] "0.9 Mb"
# COMPARE THE CLASSES ####
stack_finalized %>% class()
#> [1] "linear_stack" "model_stack" "list"
stack_finalized_trimmed %>% class() # just a list
#> [1] "list"
# MAKE PREDICTIONS WITH WHOLE MODEL_STACK ####
# Class probabilities for the first six training rows.
head(
  stacks::predict.model_stack(
    stack_finalized,
    new_data = tree_frogs_class_train,
    type = "prob"
  )
)
#> # A tibble: 6 × 2
#> .pred_yes .pred_no
#> <dbl> <dbl>
#> 1 0.934 0.0660
#> 2 0.397 0.603
#> 3 0.602 0.398
#> 4 0.165 0.835
#> 5 0.0815 0.918
#> 6 0.112 0.888
# MAKE PREDICTIONS WITH TRIMMED MODEL_STACK ####
# The trimmed object is a plain list, so the method is called explicitly.
head(
  stacks::predict.model_stack(
    stack_finalized_trimmed,
    new_data = tree_frogs_class_train,
    type = "prob"
  )
)
#> # A tibble: 6 × 2
#> .pred_yes .pred_no
#> <dbl> <dbl>
#> 1 0.934 0.0660
#> 2 0.397 0.603
#> 3 0.602 0.398
#> 4 0.165 0.835
#> 5 0.0815 0.918
#> 6 0.112 0.888
# still works!
Created on 2024-05-01 with reprex v2.1.0
Session info
# Record the R version, platform and package versions used for this reprex.
sessioninfo::session_info()
#> - Session info ---------------------------------------------------------------
#> setting value
#> version R version 4.3.3 (2024-02-29 ucrt)
#> os Windows 11 x64 (build 22631)
#> system x86_64, mingw32
#> ui RTerm
#> language (EN)
#> collate English_Canada.utf8
#> ctype English_Canada.utf8
#> tz America/Edmonton
#> date 2024-05-01
#> pandoc 3.1.6.1 @ C:/PROGRA~1/Pandoc/ (via rmarkdown)
#>
#> - Packages -------------------------------------------------------------------
#> package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.0)
#> broom * 1.0.5 2023-06-09 [1] CRAN (R 4.3.1)
#> butcher 0.3.4 2024-04-11 [1] CRAN (R 4.3.3)
#> class 7.3-22 2023-05-03 [2] CRAN (R 4.3.3)
#> cli 3.6.2 2023-12-11 [1] CRAN (R 4.3.3)
#> codetools 0.2-20 2024-03-31 [1] CRAN (R 4.3.3)
#> colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.1)
#> data.table 1.15.4 2024-03-30 [1] CRAN (R 4.3.3)
#> dials * 1.2.1 2024-02-22 [1] CRAN (R 4.3.3)
#> DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.3.3)
#> digest 0.6.35 2024-03-11 [1] CRAN (R 4.3.3)
#> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.3.3)
#> ellipsis 0.3.2 2021-04-29 [1] CRAN (R 4.3.1)
#> evaluate 0.23 2023-11-01 [1] CRAN (R 4.3.2)
#> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.3.3)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.1)
#> forcats * 1.0.0 2023-01-29 [1] CRAN (R 4.3.1)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.1)
#> fs 1.6.3 2023-07-20 [1] CRAN (R 4.3.1)
#> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.1)
#> future 1.33.2 2024-03-26 [1] CRAN (R 4.3.3)
#> future.apply 1.11.2 2024-03-28 [1] CRAN (R 4.3.3)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.1)
#> ggplot2 * 3.5.0 2024-02-23 [1] CRAN (R 4.3.3)
#> glmnet * 4.1-8 2023-08-22 [1] CRAN (R 4.3.2)
#> globals 0.16.3 2024-03-08 [1] CRAN (R 4.3.3)
#> glue 1.7.0 2024-01-09 [1] CRAN (R 4.3.3)
#> gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.1)
#> gtable 0.3.4 2023-08-21 [1] CRAN (R 4.3.1)
#> hardhat 1.3.1 2024-02-02 [1] CRAN (R 4.3.3)
#> hms 1.1.3 2023-03-21 [1] CRAN (R 4.3.1)
#> htmltools 0.5.8.1 2024-04-04 [1] CRAN (R 4.3.3)
#> infer * 1.0.7 2024-03-25 [1] CRAN (R 4.3.3)
#> ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.1)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.1)
#> kernlab * 0.9-32 2023-01-31 [1] CRAN (R 4.3.0)
#> knitr 1.46 2024-04-06 [1] CRAN (R 4.3.3)
#> lattice 0.22-6 2024-03-20 [1] CRAN (R 4.3.3)
#> lava 1.8.0 2024-03-05 [1] CRAN (R 4.3.3)
#> lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.1)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.3.2)
#> listenv 0.9.1 2024-01-29 [1] CRAN (R 4.3.3)
#> lubridate * 1.9.3 2023-09-27 [1] CRAN (R 4.3.3)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.1)
#> MASS 7.3-60.0.1 2024-01-13 [1] CRAN (R 4.3.3)
#> Matrix * 1.6-5 2024-01-11 [1] CRAN (R 4.3.2)
#> modeldata * 1.3.0 2024-01-21 [1] CRAN (R 4.3.3)
#> munsell 0.5.1 2024-04-01 [1] CRAN (R 4.3.3)
#> nnet 7.3-19 2023-05-03 [1] CRAN (R 4.3.1)
#> parallelly 1.37.1 2024-02-29 [1] CRAN (R 4.3.3)
#> parsnip * 1.2.1 2024-03-22 [1] CRAN (R 4.3.3)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.1)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.1)
#> prodlim 2023.08.28 2023-08-28 [1] CRAN (R 4.3.3)
#> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.1)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.3.1)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.3.0)
#> R.oo 1.26.0 2024-01-24 [1] CRAN (R 4.3.3)
#> R.utils 2.12.3 2023-11-18 [1] CRAN (R 4.3.2)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.2)
#> Rcpp 1.0.12 2024-01-09 [1] CRAN (R 4.3.3)
#> readr * 2.1.5 2024-01-10 [1] CRAN (R 4.3.3)
#> recipes * 1.0.10.9000 2024-03-03 [1] Github (tidymodels/recipes@7858c1e)
#> reprex 2.1.0 2024-01-11 [1] CRAN (R 4.3.3)
#> rlang 1.1.3 2024-01-10 [1] CRAN (R 4.3.3)
#> rmarkdown 2.26 2024-03-05 [1] CRAN (R 4.3.3)
#> rpart 4.1.23 2023-12-05 [1] CRAN (R 4.3.3)
#> rsample * 1.2.1 2024-03-25 [1] CRAN (R 4.3.3)
#> scales * 1.3.0 2023-11-28 [1] CRAN (R 4.3.2)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.1)
#> shape 1.4.6.1 2024-02-23 [1] CRAN (R 4.3.2)
#> stacks * 1.0.4.9000 2024-04-29 [1] Github (tidymodels/stacks@f95369f)
#> stringi 1.8.3 2023-12-11 [1] CRAN (R 4.3.2)
#> stringr * 1.5.1 2023-11-14 [1] CRAN (R 4.3.2)
#> styler 1.10.3 2024-04-07 [1] CRAN (R 4.3.3)
#> survival 3.5-8 2024-02-14 [1] CRAN (R 4.3.3)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.1)
#> tidymodels * 1.2.0 2024-03-25 [1] CRAN (R 4.3.3)
#> tidyr * 1.3.1 2024-01-24 [1] CRAN (R 4.3.3)
#> tidyselect 1.2.1 2024-03-11 [1] CRAN (R 4.3.3)
#> tidyverse * 2.0.0 2023-02-22 [1] CRAN (R 4.3.1)
#> timechange 0.3.0 2024-01-18 [1] CRAN (R 4.3.3)
#> timeDate 4032.109 2023-12-14 [1] CRAN (R 4.3.2)
#> tune * 1.2.0 2024-03-20 [1] CRAN (R 4.3.3)
#> tzdb 0.4.0 2023-05-12 [1] CRAN (R 4.3.1)
#> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.1)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.3.3)
#> withr 3.0.0 2024-01-16 [1] CRAN (R 4.3.3)
#> workflows * 1.1.4 2024-02-19 [1] CRAN (R 4.3.3)
#> workflowsets * 1.1.0 2024-03-21 [1] CRAN (R 4.3.3)
#> xfun 0.43 2024-03-25 [1] CRAN (R 4.3.3)
#> yaml 2.3.8 2023-12-11 [1] CRAN (R 4.3.2)
#> yardstick * 1.3.1 2024-03-21 [1] CRAN (R 4.3.3)
#>
#> [1] C:/Users/pgaut/AppData/Local/R/win-library/4.3
#> [2] C:/Program Files/R/R-4.3.3/library
#>
#> ------------------------------------------------------------------------------
Thanks for the informative issue description! stacks implements butcher methods for model stacks that do exactly what you're looking for. Calling butcher() on your stack_finalized object should trim off all of the components of the stack that aren't needed for prediction. This can often reduce the object size in memory by an order of magnitude or more.
If you find that the butchered model stack still contains large objects that aren't needed for prediction (you can uncover these with butcher::weigh() and trim them off manually), feel free to mention them here and we can look into updating the relevant butcher method.
Thanks, butcher seems to be exactly what I was hoping for. However, I'm seeing a discrepancy between the size of the butchered model stack in the R session and the size of the file it produces when saved to disk. This may be unrelated to stacks, but I don't know enough to interpret it. Any suggestions or pointers?
# HOUSEKEEPING ####
# No `rm(list = ls())` here: clearing the global environment does not detach
# packages or reset options, so it is not a clean slate. Run the script in a
# fresh R session instead.
# CRAN libraries
library(tidyverse) # install.packages("tidyverse")
library(tidymodels) # install.packages("tidymodels")
library(stacks) # install.packages("stacks")
library(tictoc) # install.packages("tictoc")
# DEFINE SEED ####
seed <- 42
# SOME RANDOM DATA TO INCREASE OBJECT SIZE ####
# Draw one value per cell. Passing only 121200 values for a 1212 x 10000
# matrix made matrix() silently recycle them 100 times, so every block of 100
# columns was an exact copy of the first. The row count comes from the data
# so bind_cols() below always lines up, and seeding first makes the padding
# reproducible.
# NOTE: the `#>` output recorded further down was produced with the earlier
# recycled, unseeded padding, so sizes and timings may differ somewhat.
n_extra_cols <- 10000
set.seed(seed)
cols_extra <- matrix(
  runif(nrow(tree_frogs) * n_extra_cols, min = 1, max = 10),
  nrow = nrow(tree_frogs),
  ncol = n_extra_cols
) %>%
  as_tibble(.name_repair = tidyr_legacy)
# DEFINE SET ####
tree_frogs_class_train <- tree_frogs %>%
  dplyr::select(-c(clutch, latency)) %>%
  bind_cols(cols_extra)
# SET UP RESAMPLING ####
# Re-seed right before assigning folds so the 5 splits do not depend on any
# random draws made earlier in the script.
set.seed(seed)
cv_folds <- rsample::vfold_cv(tree_frogs_class_train, v = 5)
# BUILD MODELS ####
# Two glmnet logistic regressions: a lasso with mixture fixed at 1, and an
# elastic net whose mixture is tuned together with the penalty.
mod_lasso <- parsnip::logistic_reg(penalty = tune(), mixture = 1) %>%
  parsnip::set_engine(engine = "glmnet") %>%
  parsnip::set_mode(mode = "classification")
mod_elasticnet <- parsnip::logistic_reg(penalty = tune(), mixture = tune()) %>%
  parsnip::set_engine(engine = "glmnet") %>%
  parsnip::set_mode(mode = "classification")
# BUILD RECIPE ####
# Role-aware selectors restrict dummy encoding and normalization to the
# predictors, so the `hatched` outcome no longer has to be excluded by hand.
recipe_set <- recipes::recipe(hatched ~ ., data = tree_frogs_class_train) %>%
  recipes::step_dummy(recipes::all_nominal_predictors()) %>%
  recipes::step_zv(recipes::all_predictors()) %>%
  recipes::step_normalize(recipes::all_numeric_predictors())
# CREATE WORKFLOW ####
# Combine the single recipe with each named model specification.
workflow_set <- workflowsets::workflow_set(
  models = list(lasso = mod_lasso, elasticnet = mod_elasticnet),
  preproc = list(recipe_set)
)
# TUNING ####
# Tune every workflow over 10 grid candidates, scored by ROC AUC. Resampled
# predictions and workflows are saved so stacks::add_candidates() can use them;
# as the messages below note, the saved recipe is large.
res_tune <- workflow_set %>%
  workflowsets::workflow_map(
    seed = seed,
    resamples = cv_folds,
    fn = "tune_grid",
    grid = 10,
    metrics = yardstick::metric_set(roc_auc),
    control = tune::control_grid(
      save_pred = TRUE,
      save_workflow = TRUE,
      verbose = FALSE,
      allow_par = FALSE
    )
  )
#> ℹ The workflow being saved contains a recipe, which is 98.73 Mb in memory. If
#> this was not intentional, please set the control setting `save_workflow =
#> FALSE`.
#> ℹ The workflow being saved contains a recipe, which is 98.73 Mb in memory. If
#> this was not intentional, please set the control setting `save_workflow =
#> FALSE`.
# DEFINE ENSEMBLE STACK ####
# Register the tuning results as candidates, estimate the blend, then fit the
# member models.
stack_finalized <- stacks::fit_members(
  stacks::blend_predictions(
    stacks::add_candidates(stacks::stacks(), res_tune)
  )
)
#> Warning: Predictions from 14 candidates were identical to those from existing candidates
#> and were removed from the data stack.
# BUTCHER THE MODEL_STACK ####
# Trim off the components of the stack that are not needed for prediction.
stack_finalized_butchered <- stack_finalized %>%
  stacks::butcher()
# SAVE THE BUTCHERED MODEL_STACK TO DISK ####
# Write to a file in the session's temporary directory rather than a
# hard-coded Windows path, so the script runs on any machine and the file is
# removed when the R session ends.
path <- tempfile("stack_finalized_butchered", fileext = ".RData")
tic()
save(stack_finalized_butchered, file = path)
toc()
#> 101.2 sec elapsed
# CHECK THE SIZE OF THE MODEL_STACK ####
lobstr::obj_size(stack_finalized)
#> 415.43 MB
# CHECK THE SIZE OF THE BUTCHERED MODEL_STACK ####
lobstr::obj_size(stack_finalized_butchered)
#> 122.69 MB
# CHECK THE SIZE OF THE BUTCHERED MODEL_STACK .RData FILE ####
# file.size() reports bytes; divide by 1024^2 to express the result in MiB.
file.size(path) / 1024^2
#> [1] 741.9729
Created on 2024-05-01 with reprex v2.1.0
Just confirming that I'm able to reproduce this. Somewhat reminds me of #117—may need to revisit.
For a stack st, the attribute attr(st$coefs$preproc$terms, ".Environment") is inflated after saving and reloading. This can be reproduced even without butcher.
library(tidymodels)
library(stacks)
# Build a small example stack from the `class_res_nn` and `class_res_rf`
# tuning results: add both as candidates, blend, then fit the members.
st <- fit_members(
  blend_predictions(
    add_candidates(
      add_candidates(stacks(), class_res_nn),
      class_res_rf
    )
  )
)
#> Warning: Predictions from 1 candidate were identical to those from existing candidates
#> and were removed from the data stack.
# Round-trip the stack through an .RData file. `st_orig` keeps the in-memory
# object so it can be compared with the reloaded copy afterwards.
path <- "st.RData"
save(st, file = path)
lobstr::obj_size(st)
#> 16.77 MB
st_orig <- st
rm(st)
# load() restores the saved object under its original name, `st`.
load(path)
# the size is inflated after save...
lobstr::obj_size(st)
#> 18.14 MB
That difference comes from coefs.preproc.terms. It's only slight here, but in the OP's post the difference is on the order of gigabytes:
# Per-component sizes of the stack before the save/load round trip.
butcher::weigh(st_orig)
#> # A tibble: 150,938 × 2
#> object size
#> <chr> <dbl>
#> 1 coefs.preproc.terms 0.640
#> 2 model_defs.class_res_nn.pre.actions.recipe.recipe.steps.terms 0.115
#> 3 model_defs.class_res_rf.pre.actions.recipe.recipe.steps.terms 0.115
#> 4 member_fits.class_res_nn_1_1.pre.actions.recipe.recipe.steps.terms 0.115
#> 5 member_fits.class_res_nn_1_1.pre.mold.blueprint.recipe.steps.terms 0.115
#> 6 member_fits.class_res_rf_1_06.pre.actions.recipe.recipe.steps.terms 0.115
#> 7 member_fits.class_res_rf_1_06.pre.mold.blueprint.recipe.steps.terms 0.115
#> 8 member_fits.class_res_rf_1_10.pre.actions.recipe.recipe.steps.terms 0.115
#> 9 member_fits.class_res_rf_1_10.pre.mold.blueprint.recipe.steps.terms 0.115
#> 10 member_fits.class_res_rf_1_03.pre.actions.recipe.recipe.steps.terms 0.115
#> # ℹ 150,928 more rows
# The same breakdown for the reloaded stack; `coefs.preproc.terms` is larger.
butcher::weigh(st)
#> # A tibble: 150,938 × 2
#> object size
#> <chr> <dbl>
#> 1 coefs.preproc.terms 0.664
#> 2 model_defs.class_res_nn.pre.actions.recipe.recipe.steps.terms 0.115
#> 3 model_defs.class_res_rf.pre.actions.recipe.recipe.steps.terms 0.115
#> 4 member_fits.class_res_nn_1_1.pre.actions.recipe.recipe.steps.terms 0.115
#> 5 member_fits.class_res_nn_1_1.pre.mold.blueprint.recipe.steps.terms 0.115
#> 6 member_fits.class_res_rf_1_06.pre.actions.recipe.recipe.steps.terms 0.115
#> 7 member_fits.class_res_rf_1_06.pre.mold.blueprint.recipe.steps.terms 0.115
#> 8 member_fits.class_res_rf_1_10.pre.actions.recipe.recipe.steps.terms 0.115
#> 9 member_fits.class_res_rf_1_10.pre.mold.blueprint.recipe.steps.terms 0.115
#> 10 member_fits.class_res_rf_1_03.pre.actions.recipe.recipe.steps.terms 0.115
#> # ℹ 150,928 more rows
# Break the original terms object down by attribute; `.Environment` dominates.
butcher::weigh(attributes(st_orig$coefs$preproc$terms))
#> # A tibble: 10 × 2
#> object size
#> <chr> <dbl>
#> 1 .Environment 0.628
#> 2 factors 0.00442
#> 3 variables 0.00258
#> 4 predvars 0.00258
#> 5 dataClasses 0.00241
#> 6 term.labels 0.00190
#> 7 order 0.000176
#> 8 class 0.000176
#> 9 intercept 0.000056
#> 10 response 0.000056
# Same breakdown after reloading: only `.Environment` changes size.
butcher::weigh(attributes(st$coefs$preproc$terms))
#> # A tibble: 10 × 2
#> object size
#> <chr> <dbl>
#> 1 .Environment 0.652
#> 2 factors 0.00442
#> 3 variables 0.00258
#> 4 predvars 0.00258
#> 5 dataClasses 0.00241
#> 6 term.labels 0.00190
#> 7 order 0.000176
#> 8 class 0.000176
#> 9 intercept 0.000056
#> 10 response 0.000056
Strangely, this difference doesn't show up in a waldo::compare() call, likely because of a shared reference:
# Compare the terms object's `.Environment` attribute before and after the
# save/load round trip; local() keeps the helper names out of the workspace.
local({
  env_before <- attr(st_orig$coefs$preproc$terms, ".Environment")
  env_after <- attr(st$coefs$preproc$terms, ".Environment")
  waldo::compare(env_before, env_after)
})
#> ✔ No differences
Created on 2024-07-15 with reprex v2.1.0