Feature suggestion: Extract splits from tune results as a resampling object
Feature suggestion
Now that we have the new {tailor} package for post-processing in tidymodels, I find myself needing to reuse the splits from a tune_results object as a resampling object.
I believe a new extract_resamples() function (or whatever name you prefer) would make interactive use of tidymodels easier.
Here is a minimal reproducible example demonstrating its use:
# pak::pak(
# paste0(
# "tidymodels/",
# c("tune", "workflows", "rsample", "tailor")
# )
# )
library(tidyverse)
library(tidymodels)
library(probably)
#>
#> Attaching package: 'probably'
#> The following objects are masked from 'package:base':
#>
#> as.factor, as.ordered
library(tailor)
library(stacks)
# How well are our predictions calibrated? Not so well
data(deliveries)
set.seed(1)
delivery_split <- initial_split(deliveries)
delivery_train <- training(delivery_split)
delivery_test <- testing(delivery_split)
set.seed(1)
delivery_folds <- vfold_cv(delivery_train)
delivery_res <-
  workflow() |>
  add_formula(time_to_delivery ~ .) |>
  add_model(boost_tree(mode = "regression", trees = 3)) |>
  fit_resamples(
    delivery_folds,
    control = control_stack_resamples()
  )
delivery_res |>
  collect_predictions() |>
  cal_plot_regression(truth = time_to_delivery, estimate = .pred)
delivery_res |> collect_metrics()
#> # A tibble: 2 × 6
#> .metric .estimator mean n std_err .config
#> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 rmse standard 9.52 10 0.0533 Preprocessor1_Model1
#> 2 rsq standard 0.853 10 0.00357 Preprocessor1_Model1
# Reuse the splits already stored in the tune results as an rset
extract_resamples <- \(x) {
  stopifnot(inherits(x, "tune_results"))
  # Rebuild a bare rset from the stored splits and ids
  result_rset <- manual_rset(x$splits, x$id)
  new_attrs <- attributes(result_rset)[c("names", "row.names")]
  # tune keeps the original rset's attributes (e.g. v, repeats, strata) in "rset_info"
  existing_attrs <- attributes(x)$rset_info$att
  att <- modifyList(existing_attrs, new_attrs)
  # Restore the original subclass (e.g. "vfold_cv") ahead of the generic classes
  desired_classes <- c(att$class, "rset", "tbl_df", "tbl", "data.frame")
  att$class <- NULL
  attributes(result_rset) <- att
  class(result_rset) <- desired_classes
  result_rset
}
waldo::compare(delivery_folds, extract_resamples(delivery_res))
#> ✔ No differences
# Add numeric calibration with {tailor}, reusing the splits saved in the tune results
delivery_res_improved <-
  delivery_res |>
  extract_workflow() |>
  add_tailor(tailor() |> adjust_numeric_calibration()) |>
  fit_resamples(
    extract_resamples(delivery_res),
    control = control_stack_resamples()
  )
delivery_res_improved |> collect_metrics()
#> # A tibble: 2 × 6
#> .metric .estimator mean n std_err .config
#> <chr> <chr> <dbl> <int> <dbl> <chr>
#> 1 rmse standard 2.71 10 0.0300 Preprocessor1_Model1
#> 2 rsq standard 0.846 10 0.00432 Preprocessor1_Model1
# Much better
delivery_res_improved |>
  collect_predictions() |>
  cal_plot_regression(truth = time_to_delivery, estimate = .pred)
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.3 (2024-02-29)
#> os Ubuntu 22.04.4 LTS
#> system x86_64, linux-gnu
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Europe/Madrid
#> date 2024-10-09
#> pandoc 2.9.2.1 @ /bin/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> backports 1.4.1 2021-12-13 [1] CRAN (R 4.3.0)
#> broom * 1.0.5 2023-06-09 [1] CRAN (R 4.3.1)
#> butcher 0.3.3 2023-08-23 [1] CRAN (R 4.3.2)
#> class 7.3-22 2023-05-03 [2] CRAN (R 4.3.3)
#> cli 3.6.2 2023-12-11 [1] CRAN (R 4.3.2)
#> codetools 0.2-19 2023-02-01 [2] CRAN (R 4.3.3)
#> colorspace 2.1-0 2023-01-23 [1] CRAN (R 4.3.0)
#> data.table 1.15.99 2024-02-20 [1] Github (Rdatatable/data.table@8f8ef93)
#> dials * 1.3.0 2024-07-30 [1] RSPM
#> DiceDesign 1.10 2023-12-07 [1] CRAN (R 4.3.2)
#> digest 0.6.35 2024-03-11 [1] RSPM (R 4.3.0)
#> dplyr * 1.1.4 2023-11-17 [1] CRAN (R 4.3.2)
#> evaluate 0.23 2023-11-01 [1] CRAN (R 4.3.2)
#> fansi 1.0.6 2023-12-08 [1] CRAN (R 4.3.2)
#> farver 2.1.1 2022-07-06 [1] CRAN (R 4.3.0)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.0)
#> forcats * 1.0.0 2023-01-29 [1] CRAN (R 4.3.2)
#> foreach 1.5.2 2022-02-02 [1] CRAN (R 4.3.0)
#> fs 1.6.3 2023-07-20 [1] CRAN (R 4.3.1)
#> furrr 0.3.1 2022-08-15 [1] CRAN (R 4.3.0)
#> future 1.33.1 2023-12-22 [1] CRAN (R 4.3.2)
#> future.apply 1.11.1 2023-12-21 [1] CRAN (R 4.3.2)
#> generics 0.1.3 2022-07-05 [1] CRAN (R 4.3.0)
#> ggplot2 * 3.5.0 2024-02-23 [1] RSPM (R 4.3.0)
#> globals 0.16.3 2024-03-08 [1] RSPM (R 4.3.0)
#> glue 1.7.0 2024-01-09 [1] RSPM (R 4.3.0)
#> gower 1.0.1 2022-12-22 [1] CRAN (R 4.3.0)
#> GPfit 1.0-8 2019-02-08 [1] CRAN (R 4.3.0)
#> gtable 0.3.4 2023-08-21 [1] CRAN (R 4.3.1)
#> hardhat 1.4.0 2024-06-02 [1] RSPM
#> hms 1.1.3 2023-03-21 [1] CRAN (R 4.3.0)
#> htmltools 0.5.8 2024-03-25 [1] RSPM (R 4.3.0)
#> infer * 1.0.7 2024-03-25 [1] RSPM (R 4.3.0)
#> ipred 0.9-14 2023-03-09 [1] CRAN (R 4.3.0)
#> iterators 1.0.14 2022-02-05 [1] CRAN (R 4.3.0)
#> jsonlite 1.8.8 2023-12-04 [1] CRAN (R 4.3.2)
#> knitr 1.45 2023-10-30 [1] CRAN (R 4.3.2)
#> labeling 0.4.3 2023-08-29 [1] CRAN (R 4.3.1)
#> lattice 0.22-5 2023-10-24 [2] CRAN (R 4.3.3)
#> lava 1.8.0 2024-03-05 [1] RSPM (R 4.3.0)
#> lhs 1.1.6 2022-12-17 [1] CRAN (R 4.3.0)
#> lifecycle 1.0.4 2023-11-07 [1] CRAN (R 4.3.2)
#> listenv 0.9.1 2024-01-29 [1] RSPM (R 4.3.0)
#> lubridate * 1.9.3 2023-09-27 [1] CRAN (R 4.3.2)
#> magrittr 2.0.3 2022-03-30 [1] CRAN (R 4.3.0)
#> MASS 7.3-60.0.1 2024-01-13 [2] CRAN (R 4.3.3)
#> Matrix 1.6-5 2024-01-11 [1] RSPM (R 4.3.0)
#> mgcv 1.9-1 2023-12-21 [2] CRAN (R 4.3.3)
#> modeldata * 1.3.0 2024-01-21 [1] RSPM (R 4.3.0)
#> modelenv 0.1.1 2023-03-08 [1] CRAN (R 4.3.0)
#> munsell 0.5.0 2018-06-12 [1] CRAN (R 4.3.0)
#> nlme 3.1-164 2023-11-27 [2] CRAN (R 4.3.3)
#> nnet 7.3-19 2023-05-03 [2] CRAN (R 4.3.3)
#> parallelly 1.37.1 2024-02-29 [1] RSPM (R 4.3.0)
#> parsnip * 1.2.1.9002 2024-10-08 [1] Github (tidymodels/parsnip@5ce414e)
#> pillar 1.9.0 2023-03-22 [1] CRAN (R 4.3.0)
#> pkgconfig 2.0.3 2019-09-22 [1] CRAN (R 4.3.0)
#> probably * 1.0.3.9001 2024-10-08 [1] Github (tidymodels/probably@545f9ab)
#> prodlim 2023.08.28 2023-08-28 [1] CRAN (R 4.3.2)
#> purrr * 1.0.2 2023-08-10 [1] CRAN (R 4.3.1)
#> R.cache 0.16.0 2022-07-21 [1] CRAN (R 4.3.1)
#> R.methodsS3 1.8.2 2022-06-13 [1] CRAN (R 4.3.1)
#> R.oo 1.26.0 2024-01-24 [1] CRAN (R 4.3.2)
#> R.utils 2.12.3 2023-11-18 [1] CRAN (R 4.3.2)
#> R6 2.5.1 2021-08-19 [1] CRAN (R 4.3.0)
#> Rcpp 1.0.12 2024-01-09 [1] RSPM (R 4.3.0)
#> readr * 2.1.5 2024-01-10 [1] RSPM (R 4.3.0)
#> recipes * 1.0.10 2024-02-18 [1] RSPM (R 4.3.0)
#> reprex 2.1.0.9000 2024-01-18 [1] Github (tidyverse/reprex@e1f65e9)
#> rlang 1.1.3 2024-01-10 [1] RSPM (R 4.3.0)
#> rmarkdown 2.26 2024-03-05 [1] RSPM (R 4.3.0)
#> rpart 4.1.23 2023-12-05 [1] RSPM
#> rsample * 1.2.1.9000 2024-10-08 [1] Github (tidymodels/rsample@f799dba)
#> scales * 1.3.0 2023-11-28 [1] CRAN (R 4.3.2)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.0)
#> sparsevctrs 0.1.0.9002 2024-10-08 [1] Github (r-lib/sparsevctrs@b29b723)
#> stacks * 1.0.4 2024-03-21 [1] RSPM (R 4.3.0)
#> stringi 1.8.3 2023-12-11 [1] CRAN (R 4.3.2)
#> stringr * 1.5.1 2023-11-14 [1] CRAN (R 4.3.2)
#> styler 1.10.2 2023-08-29 [1] CRAN (R 4.3.2)
#> survival 3.5-8 2024-02-14 [2] CRAN (R 4.3.3)
#> tailor * 0.0.0.9001 2024-10-08 [1] Github (tidymodels/tailor@317a4db)
#> tibble * 3.2.1 2023-03-20 [1] CRAN (R 4.3.0)
#> tidymodels * 1.2.0 2024-03-25 [1] RSPM (R 4.3.0)
#> tidyr * 1.3.1 2024-01-24 [1] CRAN (R 4.3.2)
#> tidyselect 1.2.1 2024-03-11 [1] RSPM (R 4.3.0)
#> tidyverse * 2.0.0.9000 2024-02-20 [1] Github (tidyverse/tidyverse@62f32d4)
#> timechange 0.3.0 2024-01-18 [1] RSPM (R 4.3.0)
#> timeDate 4032.109 2023-12-14 [1] CRAN (R 4.3.2)
#> tune * 1.2.1.9000 2024-10-08 [1] Github (tidymodels/tune@f8d734a)
#> tzdb 0.4.0 2023-05-12 [1] CRAN (R 4.3.0)
#> utf8 1.2.4 2023-10-22 [1] CRAN (R 4.3.2)
#> vctrs 0.6.5 2023-12-01 [1] CRAN (R 4.3.2)
#> waldo 0.5.2 2023-11-02 [1] CRAN (R 4.3.2)
#> withr 3.0.0 2024-01-16 [1] CRAN (R 4.3.2)
#> workflows * 1.1.4.9000 2024-10-08 [1] Github (tidymodels/workflows@78aa5df)
#> workflowsets * 1.1.0 2024-03-21 [1] RSPM (R 4.3.0)
#> xfun 0.43 2024-03-25 [1] RSPM (R 4.3.0)
#> xgboost * 1.7.7.1 2024-01-25 [1] RSPM (R 4.3.0)
#> yaml 2.3.8 2023-12-11 [1] CRAN (R 4.3.2)
#> yardstick * 1.3.1 2024-03-21 [1] RSPM (R 4.3.0)
#>
#> [1] /home/jordi/R/x86_64-pc-linux-gnu-library/4.3
#> [2] /opt/R/4.3.3/lib/R/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
Created on 2024-10-09 with [reprex v2.1.0.9000](https://reprex.tidyverse.org/)
This implementation gives identical results for my vfold_cv example, but other rset types should probably be tested as well.
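One way to start on that: a rough sketch that fits a cheap linear model with three rset types and checks that extract_resamples() restores each object's class and ids. Using mtcars, linear_reg() and these particular rset types is an assumption for illustration only. Repeated v-fold CV is deliberately left out, because the helper passes only x$id to manual_rset() and would therefore drop the id2 column.

```r
library(tidymodels)

# Helper copied from the reprex above: rebuild an rset from the splits and
# the "rset_info" attribute that tune stores on a tune_results object.
extract_resamples <- function(x) {
  stopifnot(inherits(x, "tune_results"))
  result_rset <- rsample::manual_rset(x$splits, x$id)
  new_attrs <- attributes(result_rset)[c("names", "row.names")]
  existing_attrs <- attributes(x)$rset_info$att
  att <- utils::modifyList(existing_attrs, new_attrs)
  desired_classes <- c(att$class, "rset", "tbl_df", "tbl", "data.frame")
  att$class <- NULL
  attributes(result_rset) <- att
  class(result_rset) <- desired_classes
  result_rset
}

set.seed(123)
rsets <- list(
  vfold_cv   = vfold_cv(mtcars, v = 3),
  bootstraps = bootstraps(mtcars, times = 3),
  mc_cv      = mc_cv(mtcars, times = 3)
)

# For each rset type: fit a quick model, pull the rset back out of the
# tune_results, and check that class and ids survive the round trip.
round_trip_ok <- vapply(
  rsets,
  function(rs) {
    res <- fit_resamples(linear_reg(), mpg ~ ., resamples = rs)
    out <- extract_resamples(res)
    identical(class(out), class(rs)) && identical(out$id, rs$id)
  },
  logical(1)
)
round_trip_ok
```

For a full attribute-by-attribute check, waldo::compare(rs, out) can be run on each pair, as in the reprex.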
Could you say a little more about why you'd need to extract the splits from the tune_results rather than just reusing the splits you already have?
Note to self: FWIW, we did find a use for a similar helper in stacks:::.set_splits().
In my pipelines, I usually have one process for fitting resamples and tuning, and sometimes I save only the tune_results object and not the rset. Then, oops, I need the rset too because I want to check something, and I didn't save it. {tailor} could make this more common, since adding a post-processor means refitting on the same resamples.
I also want to try the AutoGluon inference approach, and this function could help with that.
Gotcha, thanks for the reply! I will leave this open since we can see some use cases for this, though it may not be at the top of our to-do list for a while.
If you can tell me which tests you would like included, I'll make a proper PR so we can merge it.
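Pending the maintainers' guidance, here is a minimal sketch of tests such a PR might include. It assumes testthat's third edition (so that expect_equal() uses waldo, like the reprex) and reuses mtcars with linear_reg() as a cheap model; the test names and cases are hypothetical suggestions, not requirements from the tidymodels team.

```r
library(testthat)
library(tidymodels)

# Proposed helper, as in the reprex above (the name is still open for discussion).
extract_resamples <- function(x) {
  stopifnot(inherits(x, "tune_results"))
  result_rset <- rsample::manual_rset(x$splits, x$id)
  new_attrs <- attributes(result_rset)[c("names", "row.names")]
  existing_attrs <- attributes(x)$rset_info$att
  att <- utils::modifyList(existing_attrs, new_attrs)
  desired_classes <- c(att$class, "rset", "tbl_df", "tbl", "data.frame")
  att$class <- NULL
  attributes(result_rset) <- att
  class(result_rset) <- desired_classes
  result_rset
}

test_that("extract_resamples() errors on objects that are not tune_results", {
  expect_error(extract_resamples(mtcars))
})

test_that("extract_resamples() round-trips a vfold_cv object", {
  local_edition(3) # waldo-based comparison, as in the reprex
  set.seed(1)
  folds <- vfold_cv(mtcars, v = 3)
  res <- fit_resamples(linear_reg(), mpg ~ ., resamples = folds)
  expect_equal(extract_resamples(res), folds)
})

test_that("extract_resamples() keeps the rset subclass and ids", {
  set.seed(1)
  boots <- bootstraps(mtcars, times = 3)
  res <- fit_resamples(linear_reg(), mpg ~ ., resamples = boots)
  out <- extract_resamples(res)
  expect_s3_class(out, "bootstraps")
  expect_s3_class(out, "rset")
  expect_identical(out$id, boots$id)
})
```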