tidymodels.org-legacy
Again: Error for Unusual Split of Data into Test and Train Sets
Dear all, I am returning to tidymodels and have dusted off the example discussed at https://github.com/tidymodels/tidymodels.org/issues/198. Please have a look at the reprex.
Now I get a new error/warning (and I am sure I did not have this specific issue in the past). It seems to be due to the test set consisting of only one data point, but I am not 100% sure of this, nor do I know what to do if this is the case (I really need this unusual data split). I have two questions:
- Can anyone tell me how to fix my code?
- Given that my data has a strong time component (df_ini has a "year" column), I think I should use a different resampling technique (see https://www.tmwr.org/resampling.html#rolling). It should be a one-liner, but I am running into some issues. Could someone show me how to implement rolling forecasting origin resampling in my example (once it has been fixed), e.g. non-cumulative, with an analysis set of eight samples (eight years) and an assessment set of two samples (two years)?
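For the second question, here is a minimal sketch of non-cumulative rolling forecasting origin resampling with rsample's `rolling_origin()`. It assumes the rows are already sorted by year (as they are in df_ini). So that the block runs on its own, a hypothetical toy tibble, `toy_train`, stands in for df_train; its berd values are invented.

```r
library(rsample)
library(tibble)

# Hypothetical stand-in for df_train: 16 yearly rows, sorted by year.
# The berd values are invented purely for illustration.
toy_train <- tibble(year = 2001:2016, berd = seq(100, 250, by = 10))

# Non-cumulative rolling origin: fit on 8 consecutive years,
# assess on the 2 years that follow, then slide forward by one year.
rolling_folds <- rolling_origin(
  toy_train,
  initial    = 8,     # analysis set size (years)
  assess     = 2,     # assessment set size (years)
  cumulative = FALSE  # keep the analysis window at a fixed size
)

rolling_folds  # 16 - 8 - 2 + 1 = 7 resamples

# Years used by the first resample
analysis(rolling_folds$splits[[1]])$year    # 2001 ... 2008
assessment(rolling_folds$splits[[1]])$year  # 2009 2010
```

In the reprex, this would amount to replacing the vfold_cv() line with `folded_data <- rolling_origin(df_train, initial = 8, assess = 2, cumulative = FALSE)`. One caveat: with only two rows per assessment set, R^2 in each resample is either exactly 1 (the correlation of two points is ±1) or NA (if the predictions are constant), so it is probably safer to pass `metrics = metric_set(rmse)` to tune_grid() and rank the candidates by RMSE.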
 
Thanks a lot!
library(tidymodels)
tidymodels_prefer() 
df_ini <- structure(list(year = c(1998, 2002, 2004, 2005, 2006, 2007, 2008, 
2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018), 
    capital_n1132g_lag_1 = c(3446.5, 4091.1, 3655.1, 3633.3, 
    3616.2, 3450.7, 3596.8, 3867.2, 3372.5, 3722.9, 3808.5, 4005.6, 
    3718.6, 3467.9, 4214.2, 4237.4, 4450.2), capital_n117g_lag_1 = c(4920.9, 
    7810.6, 8560.3, 8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 
    11554.3, 11849.9, 13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 
    17647.1, 18273.8), capital_n11mg_lag_1 = c(16846, 19605, 
    19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 23674.5, 21200.6, 
    20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 25019.2, 27608.2, 
    29790.1), employment_be_lag_1 = c(2834.42, 2839.72, 2765.53, 
    2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 2703.36, 2668.1, 
    2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 2750.52, 2782.9
    ), employment_c_lag_1 = c(2612.76, 2623.69, 2552.89, 2518.57, 
    2496.98, 2499.54, 2558.88, 2578, 2483.97, 2447.65, 2483.1, 
    2507.41, 2500.94, 2499.6, 2511.75, 2523.97, 2555.48), employment_j_lag_1 = c(292.93, 
    389.2, 389.45, 387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 
    392.18, 410.85, 419.75, 427.59, 438.96, 440.33, 460.84, 473.4
    ), employment_k_lag_1 = c(505.33, 507.12, 510.25, 504.63, 
    515.39, 523.45, 536.6, 550.14, 546.68, 539.96, 536.58, 534.98, 
    524.13, 518.89, 511.57, 505.32, 496.41), employment_mn_lag_1 = c(945.59, 
    1217.96, 1289.55, 1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 
    1704.65, 1762.55, 1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 
    2021.51, 2109.71), employment_oq_lag_1 = c(3065.87, 3191.75, 
    3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 3577.75, 3683.85, 
    3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 4002.74, 4095.59, 
    4171.72), employment_total_lag_1 = c(14509.58, 15127.99, 
    15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 16356.53, 
    16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 17039.6, 
    17142.13, 17365.32, 17650.21), gdp_b1gq_lag_1 = c(187849.7, 
    220525, 231862.5, 242348.3, 254075, 267824.4, 283978, 293761.9, 
    288044.1, 295896.6, 310128.6, 318653.1, 323910.2, 333146.1, 
    344269.3, 357608, 369341.3), gdp_p3_lag_1 = c(139695.2, 161175.8, 
    169405.6, 176316.4, 185871.1, 194102, 200944.4, 208857.1, 
    213630.1, 218947.2, 227250.8, 233638.1, 238329.3, 243860.6, 
    249404.3, 257166.5, 265900.2), gdp_p61_lag_1 = c(50117.6, 
    71948.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 113368.1, 
    91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 126109.3, 
    129183.6, 131524, 140057.8), gdp_p62_lag_1 = c(19441, 26444.4, 
    28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 38781.9, 
    39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 55885.5, 
    59584.7), price_index_lag_1 = c(1.2, 2.3, 1.3, 2, 2.1, 1.7, 
    2.2, 3.2, 0.4, 1.7, 3.6, 2.6, 2.1, 1.5, 0.8, 1, 2.2), value_be_lag_1 = c(40533.1, 
    48207.1, 48673.2, 50737.6, 52955.2, 56872.4, 60864.9, 61029, 
    56837.8, 58433.6, 61443, 63655.1, 64132.3, 65542.6, 67495.4, 
    71152.6, 72698.8), value_c_lag_1 = c(33441.8, 40446.6, 40467.4, 
    42014.6, 44229, 47735.5, 51552.4, 51165.9, 47129.7, 48759.3, 
    51467.7, 53234.6, 53431.4, 55169, 57458.7, 60962.8, 62196
    ), value_j_lag_1 = c(5483.7, 7326.1, 7934.1, 7756.1, 8134.2, 
    8378.8, 8532.3, 8740, 8493.9, 8518.9, 9217.1, 9405.1, 9802.1, 
    10361.4, 10695.4, 11455.3, 11720.6), value_k_lag_1 = c(9210.6, 
    9977.3, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 13205.2, 
    12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 13482.9, 
    13236.4, 13744.1), value_mn_lag_1 = c(10444, 14061.4, 15706.6, 
    16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 23255.2, 
    24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7, 32259.6
    ), value_oq_lag_1 = c(29902.7, 34179.2, 36126.8, 37329.6, 
    38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 47980.9, 
    49381.5, 50261.7, 51624.3, 53715, 55926.4, 57637.1), value_total_lag_1 = c(167323.4, 
    197076.7, 207247.6, 216098.3, 225888.1, 239076, 253604.6, 
    262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 297230.1, 
    307037.7, 318952.7, 329396.1), capital_n1132g_lag_2 = c(3599.2, 
    3996.9, 3638.4, 3655.1, 3633.3, 3616.2, 3450.7, 3596.8, 3867.2, 
    3372.5, 3722.9, 3808.5, 4005.6, 3718.6, 3467.9, 4214.2, 4237.4
    ), capital_n117g_lag_2 = c(4636.2, 7008.5, 8369.6, 8560.3, 
    8679.9, 8938.9, 9823.8, 10467.1, 11047.1, 11554.3, 11849.9, 
    13465.4, 13927.5, 15510.2, 15754.4, 16584.7, 17647.1), capital_n11mg_lag_2 = c(17181.5, 
    19677.8, 18749.6, 19381.2, 19433.5, 20051.6, 20569.8, 22646.1, 
    23674.5, 21200.6, 20919.6, 23157.7, 23520.7, 24057.7, 23832.8, 
    25019.2, 27608.2), employment_be_lag_2 = c(2870.33, 2840.19, 
    2775.22, 2765.53, 2731.08, 2709.59, 2708.39, 2774.06, 2795.6, 
    2703.36, 2668.1, 2705.1, 2731.67, 2727.16, 2725.66, 2735.69, 
    2750.52), employment_c_lag_2 = c(2626.2, 2621.08, 2562.53, 
    2552.89, 2518.57, 2496.98, 2499.54, 2558.88, 2578, 2483.97, 
    2447.65, 2483.1, 2507.41, 2500.94, 2499.6, 2511.75, 2523.97
    ), employment_j_lag_2 = c(275.08, 374.56, 400.75, 389.45, 
    387.53, 384.64, 389.29, 385.77, 392.86, 383.91, 392.18, 410.85, 
    419.75, 427.59, 438.96, 440.33, 460.84), employment_k_lag_2 = c(500.9, 
    505.13, 502.42, 510.25, 504.63, 515.39, 523.45, 536.6, 550.14, 
    546.68, 539.96, 536.58, 534.98, 524.13, 518.89, 511.57, 505.32
    ), employment_mn_lag_2 = c(904.38, 1143.78, 1248.01, 1289.55, 
    1365.29, 1425.81, 1537.88, 1622.95, 1727.76, 1704.65, 1762.55, 
    1838.16, 1896.09, 1929.09, 1950.02, 1968.83, 2021.51), employment_oq_lag_2 = c(3028.85, 
    3162.77, 3241.36, 3280.36, 3317.09, 3401.65, 3476.63, 3508.01, 
    3577.75, 3683.85, 3759.23, 3798.35, 3850.17, 3877.24, 3924.06, 
    4002.74, 4095.59), employment_total_lag_2 = c(14404.29, 15019.87, 
    15113.52, 15212.11, 15307.28, 15491.61, 15762.92, 16050.92, 
    16356.53, 16269.97, 16392.87, 16647.79, 16820.66, 16879.06, 
    17039.6, 17142.13, 17365.32), gdp_b1gq_lag_2 = c(186928.7, 
    213606.4, 226735.3, 231862.5, 242348.3, 254075, 267824.4, 
    283978, 293761.9, 288044.1, 295896.6, 310128.6, 318653.1, 
    323910.2, 333146.1, 344269.3, 357608), gdp_p3_lag_2 = c(140335.8, 
    156117.3, 164107.8, 169405.6, 176316.4, 185871.1, 194102, 
    200944.4, 208857.1, 213630.1, 218947.2, 227250.8, 233638.1, 
    238329.3, 243860.6, 249404.3, 257166.5), gdp_p61_lag_2 = c(44541.4, 
    67701.6, 74691.6, 74346.9, 83074.9, 90010.4, 100076.8, 110157.2, 
    113368.1, 91435.3, 111997.3, 123526.3, 125801.2, 123657.1, 
    126109.3, 129183.6, 131524), gdp_p62_lag_2 = c(19504.2, 24888.9, 
    28063.4, 28995.1, 30507, 33520.2, 36089.5, 39104, 43056.8, 
    38781.9, 39685.8, 43784.1, 46187.6, 49444.7, 51746, 53585.8, 
    55885.5), value_be_lag_2 = c(40076.7, 46109.4, 47967.1, 48673.2, 
    50737.6, 52955.2, 56872.4, 60864.9, 61029, 56837.8, 58433.6, 
    61443, 63655.1, 64132.3, 65542.6, 67495.4, 71152.6), value_c_lag_2 = c(32955.4, 
    38908.4, 40192.9, 40467.4, 42014.6, 44229, 47735.5, 51552.4, 
    51165.9, 47129.7, 48759.3, 51467.7, 53234.6, 53431.4, 55169, 
    57458.7, 60962.8), value_j_lag_2 = c(5576.8, 6313.9, 7737.1, 
    7934.1, 7756.1, 8134.2, 8378.8, 8532.3, 8740, 8493.9, 8518.9, 
    9217.1, 9405.1, 9802.1, 10361.4, 10695.4, 11455.3), value_k_lag_2 = c(9191, 
    10458, 10225.2, 10146.9, 10541.9, 11005.3, 11912.3, 13102.7, 
    13205.2, 12123.9, 12113.2, 12952.8, 12254.9, 12796.6, 12962.4, 
    13482.9, 13236.4), value_mn_lag_2 = c(10092, 12942.5, 15074, 
    15706.6, 16569.1, 18008.7, 19576.6, 21317, 23189.8, 22490, 
    23255.2, 24895.4, 25988.7, 26998.2, 28027.3, 29207.9, 30737.7
    ), value_oq_lag_2 = c(30224.3, 33251.5, 35065.6, 36126.8, 
    37329.6, 38288.8, 40003.1, 41511.4, 43761.3, 45817.8, 46996.6, 
    47980.9, 49381.5, 50261.7, 51624.3, 53715, 55926.4), value_total_lag_2 = c(167141.8, 
    190624.9, 202353.5, 207247.6, 216098.3, 225888.1, 239076, 
    253604.6, 262414.7, 256671, 263633.5, 276404, 283548.2, 288624.3, 
    297230.1, 307037.7, 318952.7), berd = c(2146.085, 3130.884, 
    3556.479, 4207.669, 4448.676, 4845.861, 5232.63, 5092.902, 
    5520.422, 5692.841, 6540.457, 6778.42, 7324.679, 7498.488, 
    7824.51, 7888.444, 8461.72)), row.names = c(NA, -17L), class = c("tbl_df", 
"tbl", "data.frame"))
set.seed(1234)  ## to make the results reproducible
## I need a particular custom split of my dataset: the test set consists of only the most recent observation, whereas all the rest is the training set
## see https://github.com/tidymodels/rsample/issues/158
indices <-
  list(analysis   = seq(nrow(df_ini)-1), 
       assessment = nrow(df_ini)
       )
df_split <- make_splits(indices, df_ini)
## With the default split, df_split <- initial_split(df_ini),
## the code works
df_train <- training(df_split)
df_test <- testing(df_split)
folded_data <- vfold_cv(df_train, v = 3)
glmnet_recipe <- 
  recipe(formula = berd ~ ., data = df_train) |> 
  update_role(year, new_role = "ID") |> 
  step_zv(all_predictors()) |> 
  step_normalize(all_predictors(), -all_nominal()) 
glmnet_spec <- 
  linear_reg(penalty = tune(), mixture = tune()) |> 
  set_mode("regression") |> 
  set_engine("glmnet") 
glmnet_workflow <- 
  workflow() |> 
  add_recipe(glmnet_recipe) |> 
  add_model(glmnet_spec) 
glmnet_grid <- tidyr::crossing(
  penalty = 10^seq(-6, -1, length.out = 20),
  mixture = c(0.05, 0.2, 0.4, 0.6, 0.8, 1)
)
glmnet_tune <- tune_grid(
  glmnet_workflow,
  resamples = folded_data,
  grid = glmnet_grid,
  control = control_grid(save_pred = TRUE)
)
print(collect_metrics(glmnet_tune))
#> # A tibble: 240 × 8
#>       penalty mixture .metric .estimator    mean     n std_err .config          
#>         <dbl>   <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>            
#>  1 0.000001      0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  2 0.000001      0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  3 0.00000183    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  4 0.00000183    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  5 0.00000336    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  6 0.00000336    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  7 0.00000616    0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#>  8 0.00000616    0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#>  9 0.0000113     0.05 rmse    standard   375.        3 48.9    Preprocessor1_Mo…
#> 10 0.0000113     0.05 rsq     standard     0.929     3  0.0420 Preprocessor1_Mo…
#> # … with 230 more rows
print(show_best(glmnet_tune, "rmse"))
#> # A tibble: 5 × 8
#>      penalty mixture .metric .estimator  mean     n std_err .config             
#>        <dbl>   <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
#> 1 0.000001      0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 2 0.00000183    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 3 0.00000336    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 4 0.00000616    0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
#> 5 0.0000113     0.05 rmse    standard    375.     3    48.9 Preprocessor1_Model…
best_net <- select_best(glmnet_tune, "rmse")
final_net <- finalize_workflow(
  glmnet_workflow,
  best_net
)
final_res_net <- last_fit(final_net, df_split)
#> ! train/test split: internal: A correlation computation is required, but the inputs are size zero or o...
print(final_res_net)
#> # Resampling results
#> # Manual resampling 
#> # A tibble: 1 × 6
#>   splits         id               .metrics .notes   .predictions     .workflow 
#>   <list>         <chr>            <list>   <list>   <list>           <list>    
#> 1 <split [16/1]> train/test split <tibble> <tibble> <tibble [1 × 4]> <workflow>
#> 
#> There were issues with some computations:
#> 
#>   - Warning(s) x1: A correlation computation is required, but the inputs are size ze...
#> 
#> Run `show_notes(.Last.tune.result)` for more information.
final_fit <- final_res_net |>
  collect_predictions()
show_notes(.Last.tune.result)
#> unique notes:
#> ────────────────────────────────────────────────────────────────────────────────
#> A correlation computation is required, but the inputs are size zero or one and the standard deviation cannot be computed. `NA` will be returned.
sessionInfo()
#> R version 4.2.1 (2022-06-23)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Debian GNU/Linux 11 (bullseye)
#> 
#> Matrix products: default
#> BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.9.0
#> LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.9.0
#> 
#> locale:
#>  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
#>  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
#>  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] glmnet_4.1-4       Matrix_1.4-1       yardstick_1.0.0    workflowsets_1.0.0
#>  [5] workflows_1.0.0    tune_1.0.0         tidyr_1.2.0        tibble_3.1.7      
#>  [9] rsample_1.0.0      recipes_1.0.1      purrr_0.3.4        parsnip_1.0.0     
#> [13] modeldata_1.0.0    infer_1.0.2        ggplot2_3.3.6      dplyr_1.0.9       
#> [17] dials_1.0.0        scales_1.2.0       broom_1.0.0        tidymodels_1.0.0  
#> 
#> loaded via a namespace (and not attached):
#>  [1] splines_4.2.1      foreach_1.5.2      prodlim_2019.11.13 assertthat_0.2.1  
#>  [5] conflicted_1.1.0   highr_0.9          GPfit_1.0-8        yaml_2.3.5        
#>  [9] globals_0.15.1     ipred_0.9-13       pillar_1.7.0       backports_1.4.1   
#> [13] lattice_0.20-45    glue_1.6.2         digest_0.6.29      hardhat_1.2.0     
#> [17] colorspace_2.0-3   htmltools_0.5.2    timeDate_3043.102  pkgconfig_2.0.3   
#> [21] lhs_1.1.5          DiceDesign_1.9     listenv_0.8.0      gower_1.0.0       
#> [25] lava_1.6.10        generics_0.1.2     ellipsis_0.3.2     cachem_1.0.6      
#> [29] withr_2.5.0        furrr_0.3.0        nnet_7.3-17        cli_3.3.0         
#> [33] survival_3.3-1     magrittr_2.0.3     crayon_1.5.1       memoise_2.0.1     
#> [37] evaluate_0.15      fs_1.5.2           future_1.26.1      fansi_1.0.3       
#> [41] parallelly_1.32.0  MASS_7.3-57        class_7.3-20       tools_4.2.1       
#> [45] lifecycle_1.0.1    stringr_1.4.0      munsell_0.5.0      reprex_2.0.1      
#> [49] compiler_4.2.1     rlang_1.0.4        grid_4.2.1         iterators_1.0.14  
#> [53] rmarkdown_2.14     gtable_0.3.0       codetools_0.2-18   DBI_1.1.3         
#> [57] R6_2.5.1           lubridate_1.8.0    knitr_1.39         fastmap_1.1.0     
#> [61] future.apply_1.9.0 utf8_1.2.2         shape_1.4.6        stringi_1.7.6     
#> [65] parallel_4.2.1     Rcpp_1.0.8.3       vctrs_0.4.1        rpart_4.1.16      
#> [69] tidyselect_1.1.2   xfun_0.31
Created on 2022-08-06 by the reprex package (v2.0.1)
I should add that another user reported that the above script works on his platform with tidymodels 0.2.0, so I believe the issue is not with the script per se or with glmnet.
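One plausible explanation for the version difference (an assumption, not verified against tidymodels 0.2.0) is that the one-row test set always produced an undefined R^2, and newer tune/yardstick releases simply surface it as a note: R^2 is a squared correlation, which cannot be computed from a single observation, whereas RMSE is well defined for one row. The sketch below shows this, then applies one possible fix, asking last_fit() for RMSE only through its `metrics` argument. It uses hypothetical toy data (`toy`, `x`) and the default lm engine of linear_reg() in place of glmnet so that it runs quickly on its own.

```r
library(tidymodels)

# Part 1: on a single (hypothetical) observation, R^2 is NA but RMSE is fine.
one_row  <- tibble(truth = 10, estimate = 7)
rsq_one  <- suppressWarnings(rsq(one_row, truth, estimate))
rmse_one <- rmse(one_row, truth, estimate)
rsq_one$.estimate   # NA
rmse_one$.estimate  # 3

# Part 2: the same "last year only" hold-out as in the reprex, on toy data.
set.seed(1234)
toy <- tibble(year = 2002:2018, x = rnorm(17), berd = 2 * x + rnorm(17))

# Training set = all but the most recent year; test set = the last year only
toy_indices <- list(
  analysis   = seq_len(nrow(toy) - 1),
  assessment = nrow(toy)
)
toy_split <- make_splits(toy_indices, toy)

toy_workflow <- workflow() |>
  add_recipe(
    recipe(berd ~ ., data = training(toy_split)) |>
      update_role(year, new_role = "ID")
  ) |>
  add_model(linear_reg())  # default "lm" engine, standing in for glmnet

# Request RMSE only, so no correlation-based metric is computed on one row
toy_fit <- last_fit(toy_workflow, toy_split, metrics = metric_set(rmse))
collect_metrics(toy_fit)
```

In the reprex, the corresponding change would be `final_res_net <- last_fit(final_net, df_split, metrics = metric_set(rmse))`. R^2 only becomes informative again with a test set of several observations.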