workflowsets icon indicating copy to clipboard operation
workflowsets copied to clipboard

`num_comp` not updated with `option_add()` for `discrim_flexible` workflows

Open marioem opened this issue 8 months ago • 0 comments

The problem

The usual method of finalizing tunable parameters in a workflow set is not working with num_comp for discrim_flexible workflows

Reproducible example

library(tidymodels)
library(tidyverse)
library(discrim)
#> 
#> Attaching package: 'discrim'
#> The following object is masked from 'package:dials':
#> 
#>     smoothness
library(AppliedPredictiveModeling)
library(future)

set.seed(321)
data <- quadBoundaryFunc(2000) %>% select(A = X1, B = X2, class)

data_splits1 <- initial_split(data, prop = .85, strata = class)

train_datas1 <- training(data_splits1)
test_datas1  <- testing(data_splits1)

foldss1 <- vfold_cv(train_datas1, v = 5, repeats = 3, strata = class)


biv_rec <- 
  recipe(class ~ ., data = train_datas1) %>%
  step_normalize(all_predictors())

discrim_flexible_spec <-
  discrim_flexible(num_terms =  tune::tune(), 
                   prod_degree =  tune::tune(), 
                   prune_method =  tune::tune()) %>%
  set_engine('earth') %>%
  set_mode('classification')

normalizeds1 <- 
  workflow_set(
    preproc = list(norm = biv_rec),
    models = list(FD = discrim_flexible_spec)
  )

normalizeds1 %>% extract_workflow("norm_FD") %>% extract_parameter_set_dials()
#> Collection of 3 parameters for tuning
#> 
#>    identifier         type    object
#>     num_terms    num_terms nparam[?]
#>   prod_degree  prod_degree nparam[+]
#>  prune_method prune_method dparam[+]
#> 
#> Model parameters needing finalization:
#>    # Model Terms ('num_terms')
#> 
#> See `?dials::finalize` or `?dials::update.parameters` for more information.
pars <- normalizeds1 %>% extract_workflow("norm_FD") %>% extract_parameter_set_dials() %>% finalize(x = train_datas1 %>% select(-class))

pars
#> Collection of 3 parameters for tuning
#> 
#>    identifier         type    object
#>     num_terms    num_terms nparam[+]
#>   prod_degree  prod_degree nparam[+]
#>  prune_method prune_method dparam[+]

normalizeds1 <- normalizeds1 %>% 
  option_add(param_info = pars, id = "norm_FD")

normalizeds1 %>% extract_workflow("norm_FD") %>% extract_parameter_set_dials()
#> Collection of 3 parameters for tuning
#> 
#>    identifier         type    object
#>     num_terms    num_terms nparam[?]
#>   prod_degree  prod_degree nparam[+]
#>  prune_method prune_method dparam[+]
#> 
#> Model parameters needing finalization:
#>    # Model Terms ('num_terms')
#> 
#> See `?dials::finalize` or `?dials::update.parameters` for more information.
# 'num_comp' is not getting updated

bayes_ctrl <-
  control_bayes(
    save_pred = TRUE,
    parallel_over = "everything",
    save_workflow = TRUE,
    verbose = T
  )

plan(multisession)
tune_bayes(normalizeds1 %>% extract_workflow("norm_FD"), seed = 1503, resamples = foldss1, metrics = metric_set(roc_auc, brier_class, kap, accuracy),  iter = 25, verbose = T, initial = 11, control = bayes_ctrl)
#> Error in `dials::grid_latin_hypercube()`:
#> ✖ This argument contains unknowns: `num_terms`.
#> ℹ See the `dials::finalize()` function.
# Error: This argument contains unknowns: `num_terms`.
plan(sequential)

Created on 2024-06-20 with reprex v2.1.0

Session info
sessionInfo()
#> R version 4.4.0 (2024-04-24)
#> Platform: aarch64-apple-darwin20
#> Running under: macOS Sonoma 14.5
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRblas.0.dylib 
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> time zone: UTC
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] future_1.33.2                   AppliedPredictiveModeling_1.1-7
#>  [3] discrim_1.0.1                   lubridate_1.9.3                
#>  [5] forcats_1.0.0                   stringr_1.5.1                  
#>  [7] readr_2.1.5                     tidyverse_2.0.0                
#>  [9] yardstick_1.3.1                 workflowsets_1.1.0             
#> [11] workflows_1.1.4                 tune_1.2.1                     
#> [13] tidyr_1.3.1                     tibble_3.2.1                   
#> [15] rsample_1.2.1                   recipes_1.0.10                 
#> [17] purrr_1.0.2                     parsnip_1.2.1                  
#> [19] modeldata_1.3.0                 infer_1.0.7                    
#> [21] ggplot2_3.5.1                   dplyr_1.1.4                    
#> [23] dials_1.2.1                     scales_1.3.0                   
#> [25] broom_1.0.6                     tidymodels_1.2.0               
#> 
#> loaded via a namespace (and not attached):
#>  [1] rlang_1.1.4         magrittr_2.0.3      furrr_0.3.1        
#>  [4] rpart.plot_3.1.2    compiler_4.4.0      vctrs_0.6.5        
#>  [7] reshape2_1.4.4      lhs_1.1.6           pkgconfig_2.0.3    
#> [10] fastmap_1.2.0       backports_1.5.0     utf8_1.2.4         
#> [13] rmarkdown_2.27      prodlim_2023.08.28  tzdb_0.4.0         
#> [16] xfun_0.44           reprex_2.1.0        styler_1.10.3      
#> [19] parallel_4.4.0      cluster_2.1.6       R6_2.5.1           
#> [22] CORElearn_1.57.3    stringi_1.8.4       parallelly_1.37.1  
#> [25] rpart_4.1.23        Rcpp_1.0.12         iterators_1.0.14   
#> [28] knitr_1.47          future.apply_1.11.2 R.utils_2.12.3     
#> [31] Matrix_1.7-0        splines_4.4.0       nnet_7.3-19        
#> [34] R.cache_0.16.0      timechange_0.3.0    tidyselect_1.2.1   
#> [37] rstudioapi_0.16.0   yaml_2.3.8          timeDate_4032.109  
#> [40] codetools_0.2-20    listenv_0.9.1       lattice_0.22-6     
#> [43] plyr_1.8.9          withr_3.0.0         evaluate_0.23      
#> [46] survival_3.7-0      pillar_1.9.0        foreach_1.5.2      
#> [49] ellipse_0.5.0       generics_0.1.3      hms_1.1.3          
#> [52] munsell_0.5.1       plotmo_3.6.3        globals_0.16.3     
#> [55] class_7.3-22        glue_1.7.0          mda_0.5-4          
#> [58] tools_4.4.0         data.table_1.15.4   gower_1.0.1        
#> [61] fs_1.6.4            grid_4.4.0          plotrix_3.8-4      
#> [64] ipred_0.9-14        colorspace_2.1-0    earth_5.3.3        
#> [67] Formula_1.2-5       cli_3.6.2           DiceDesign_1.10    
#> [70] fansi_1.0.6         lava_1.8.0          gtable_0.3.5       
#> [73] R.methodsS3_1.8.2   GPfit_1.0-8         digest_0.6.35      
#> [76] htmltools_0.5.8.1   R.oo_1.26.0         lifecycle_1.0.4    
#> [79] hardhat_1.4.0       MASS_7.3-60.2

marioem avatar Jun 20 '24 09:06 marioem