caretEnsemble
caretStack reorders x and y so that y is in ascending order, but does not reorder wts
In the following example I create a simple custom caret model so that I can view the x, y and wts values being sent to the model. The easiest approach would be to add browser() inside the custom model, but I use print statements instead, because they show the problem this can cause.
In the example below my weights are ascending, going from 0.01 to 1 in steps of 0.01. Because the weights are unrelated to y, in theory they should have only a random effect on the prediction. But y is sorted before being passed to the model and wts is not, so the weights no longer align with the x and y rows. Even more perniciously, in the case below they give the larger y values the higher weights, which strongly distorts the apparent weighted mean of the series.
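To make the effect concrete outside of caret, here is a minimal base R sketch using the same data as the example below. It only mimics the reported behaviour (y sorted, weights left alone) and is not caretEnsemble's internal code; the names aligned, misaligned and realigned are mine, for illustration only.

```r
# Illustration only: mimics the reported reordering, not caretEnsemble internals.
set.seed(1)
x <- rnorm(100)
w <- seq(0.01, 1, length.out = 100)
y <- x * 0.1 + rnorm(100) * 0.9

# Weights paired with the rows they were supplied for
aligned <- weighted.mean(y, w)

# y sorted ascending while the weights keep their original order
misaligned <- weighted.mean(sort(y), w)

# y sorted ascending and the weights reordered with the same permutation
ord <- order(y)
realigned <- weighted.mean(y[ord], w[ord])

print(c(aligned = aligned, misaligned = misaligned, realigned = realigned))
```

Because the weights are themselves ascending, pairing them with sorted y gives the largest weighted mean that any pairing of these values can produce, while applying the same permutation to both vectors recovers the aligned value.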
Minimal, reproducible example:
Minimal dataset:
set.seed(1)
df <- data.frame(x = rnorm(100), w = seq(0.01, 1, length.out = 100))
df$y <- df$x * 0.1 + rnorm(100) * 0.9
head(df)
tail(df)
Minimal, runnable code:
library(caret)
library(caretEnsemble)
# Mean Custom Caret Method
CaretMean <- list(
  library = c("dplyr"),
  type = "Regression",
  parameters = data.frame(parameter = c("None"),
                          class = c("character"),
                          label = c("None")),
  grid = function(x, y, len = NULL, search = "grid") {
    data.frame(None = "")
  },
  fit = function(x, y, wts, param, lev, last, weights = NA, classProbs = NA, ...) {
    RetVal <- list()
    if (is.null(wts))
      wts <- rep(1, length(y))
    # Both x and y arrive reordered so that y is ascending, but wts is not reordered.
    # The weights therefore no longer correspond to the correct x and y values.
    # In this example the weights are also increasing, so the weighted mean of y
    # comes out much higher than the unweighted mean.
    print(sprintf("Unweighted Mean y: %0.2f", mean(y)))
    print(sprintf("Weighted Mean y: %0.2f", sum(y * wts) / sum(wts)))
    # browser()
    class(RetVal) <- "CaretMean"
    return(RetVal)
  },
  predict = function(modelFit, newdata, preProc = NULL, submodels = NULL) {
    # Row-wise mean of the base-model predictions; as.matrix() lets this work
    # whether newdata arrives as a matrix or as a data frame
    unname(rowMeans(as.matrix(newdata)))
  },
  prob = NULL,
  tags = c("Simple"),
  label = "Mean"
)
models <- caretList(
  y ~ x,
  data = df,
  weights = df$w,
  trControl = trainControl(method = "cv", savePredictions = "final", allowParallel = FALSE),
  methodList = c("glm", "gbm", "svmRadialCost", "knn")
)
ensemble <- caretStack(
  models,
  method = CaretMean,
  weights = df$w,
  trControl = trainControl(method = "cv", savePredictions = "final", allowParallel = FALSE)
)
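As a complement to comparing the printed means, the reordering itself can be detected directly. The sketch below is a hypothetical helper, y_arrives_sorted, which is not part of caret or caretEnsemble; it relies only on base R's is.unsorted() and uses the same simulated y as above.

```r
# Hypothetical helper (not part of caret or caretEnsemble): reports whether
# a response vector is already in ascending order.
y_arrives_sorted <- function(y) {
  # is.unsorted() is base R; it returns FALSE for a non-decreasing vector
  !is.unsorted(y)
}

set.seed(1)
x <- rnorm(100)
y <- x * 0.1 + rnorm(100) * 0.9

original_order <- y_arrives_sorted(y)        # y as supplied by the user
sorted_order   <- y_arrives_sorted(sort(y))  # y as the fit function reportedly receives it

print(c(original = original_order, sorted = sorted_order))
```

Calling print(y_arrives_sorted(y)) next to the existing print statements in fit would show, fold by fold, whether y reaches the custom model in ascending order.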
Session Info:
>sessionInfo()
R version 3.4.1 (2017-06-30)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows Server >= 2012 x64 (build 9200)
Matrix products: default
locale:
[1] LC_COLLATE=English_United States.1252 LC_CTYPE=English_United States.1252 LC_MONETARY=English_United States.1252
[4] LC_NUMERIC=C LC_TIME=English_United States.1252
attached base packages:
[1] grid parallel splines stats4 stats graphics grDevices utils datasets methods base
other attached packages:
[1] caretEnsemble_2.0.0 caret_6.0-77 randomForest_4.6-12 data.table_1.10.4
[5] weights_0.85 mice_2.30 gdata_2.18.0 flexclust_1.3-4
[9] modeltools_0.2-21 magrittr_1.5 ROI_0.2-6 PortfolioAnalytics_1.0.3636
[13] PerformanceAnalytics_1.4.3541 xts_0.10-0 zoo_1.8-0 xgboost_0.6-4
[17] lubridate_1.6.0 bindrcpp_0.2 GenSA_1.1.6 optimx_2013.8.7
[21] doParallel_1.0.10 iterators_1.0.8 glmnet_2.0-13 foreach_1.4.3
[25] Matrix_1.2-10 tidyr_0.7.1 dplyr_0.7.3 plyr_1.8.4
[29] scales_0.5.0 car_2.1-5 MASS_7.3-47 DBI_0.7
[33] rsqlserver_1.0 rClr_0.7-4 VGAM_1.0-4 Hmisc_4.0-3
[37] ggplot2_2.2.1 Formula_1.2-2 survival_2.41-3 lattice_0.20-35
[41] RODBC_1.3-15
loaded via a namespace (and not attached):
[1] backports_1.1.0 lazyeval_0.2.0 svUnit_0.7-12 BB_2014.10-1 digest_0.6.12
[6] htmltools_0.3.6 checkmate_1.8.3 memoise_1.1.0 cluster_2.0.6 recipes_0.1.0
[11] gower_0.1.2 dimRed_0.1.0 colorspace_1.3-2 lme4_1.1-13 Rglpk_0.6-3
[16] bindr_0.1 glue_1.1.1 DRR_0.0.2 registry_0.3 gtable_0.2.0
[21] ipred_0.9-6 MatrixModels_0.4-1 kernlab_0.9-25 ddalpha_1.2.1 DEoptimR_1.0-8
[26] SparseM_1.77 setRNG_2013.9-1 Rcpp_0.12.12 htmlTable_1.9 foreign_0.8-69
[31] lava_1.5 prodlim_1.6.1 htmlwidgets_0.9 httr_1.3.1 ROI.plugin.quadprog_0.2-5
[36] RColorBrewer_1.1-2 acepack_1.4.1 pkgconfig_2.0.1 nnet_7.3-12 rlang_0.1.2
[41] reshape2_1.4.2 munsell_0.4.3 tools_3.4.1 ranger_0.8.0 devtools_1.13.3
[46] ROI.plugin.glpk_0.2-5 stringr_1.2.0 ModelMetrics_1.1.0 knitr_1.17 robustbase_0.92-7
[51] purrr_0.2.3 pbapply_1.3-3 nlme_3.1-131 quantreg_5.33 slam_0.1-40
[56] RcppRoll_0.2.2 compiler_3.4.1 pbkrtest_0.4-7 curl_2.6 e1071_1.6-8
[61] tibble_1.3.4 stringi_1.1.5 superpc_1.09 nloptr_1.0.4 gbm_2.1.3
[66] ucminf_1.1-4 R6_2.2.2 latticeExtra_0.6-28 gridExtra_2.3 codetools_0.2-15
[71] gtools_3.5.0 assertthat_0.2.0 CVST_0.2-1 Rvmmin_2017-7.18 optextras_2016-8.8
[76] withr_2.0.0 mgcv_1.8-17 Rcgmin_2013-2.21 quadprog_1.5-5 dfoptim_2016.7-1
[81] rpart_4.1-11 timeDate_3012.100 class_7.3-14 minqa_1.2.4 git2r_0.18.0
[86] numDeriv_2016.8-1 base64enc_0.1-3
Yikes! If you'd like to submit a PR to fix this, I'm happy to take it. Otherwise I'll get to it when I have some free time.
I've added a PR to fix the issue.