rstanarm
Problems when predicting from models built using R's library(parallel)
Summary / Description:
I fit a number of models in parallel via parallel::clusterMap. When I try to use posterior_predict outside the cluster, I receive the following error:
Error in formula.default(object, env = baseenv()) : invalid formula
This error can be traced to rstanarm:::.pp_data.
Specifically, to line 62 of that function:
m <- model.frame(Terms, newdata, xlev = object$xlevels)
I'm not sure whether this comes down to an issue with base R's stats package or whether it can be avoided within rstanarm. It may also simply be too much of an edge case to warrant the debugging effort.
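For context on where the message comes from: when `stats::model.frame()` receives a first argument that is not a formula, it converts it with `as.formula()`, which in turn calls `formula(object, env = baseenv())`, the call shown in the error. The base-R sketch below produces the same message from a plain list. `bad_terms` is a hypothetical stand-in; whether `Terms` inside `.pp_data` actually has this shape in my case is an assumption I have not verified.

```r
# Hypothetical stand-in for a first argument that is neither a formula nor a
# terms object (assumption: not confirmed to be what .pp_data passes).
bad_terms <- list(mean = 1, precision = 2)
newdata <- data.frame(x = c(0.1, 0.2, 0.3))

# model.frame.default() calls as.formula(), which calls
# formula(object, env = baseenv()) for non-formula input.
msg <- tryCatch(
  {
    stats::model.frame(bad_terms, newdata)
    "no error"
  },
  error = function(e) conditionMessage(e)
)
print(msg)
```

If this is the mechanism, checking `class(Terms)` just before the `model.frame` call would show whether it is still a `terms`/`formula` object at that point.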
Reproducible Steps:
## Note: this function requires library(data.table)
# @description Function for model fitting and evaluation of models built using
# \code{\link[rstanarm]{stan_betareg}}. Models are first fit, then out-of-sample predictions
# are calculated, and finally accuracy metrics (RMSE and MAD) are calculated.
# This function is intended to be used for model training in parallel.
# @param train A \code{data.table} containing the training data
# @param test A \code{data.table} containing the testing data
# @param iterations An integer scalar specifying the number of iterations for the model
# @param warmup A numeric scalar specifying the fraction of iterations to be used as warmup samples
# @param chains An integer scalar indicating the number of model chains to utilize. Defaults to \code{1L}
# @param draws An integer scalar. Passed to \code{\link[rstanarm]{posterior_predict}}; see its documentation.
# @return a \code{list} containing the model fit, out of sample predictions, and accuracy metrics
model_fit <- function(train, test, iterations, warmup, chains= 1L, draws= 1000L) {
  mod <- tryCatch(
    rstanarm::stan_betareg(
      formula= y ~ x | z,
      data= train, link = "logit", link.phi = "log",
      prior= normal(), prior_intercept= normal(), prior_phi= exponential(),
      iter= iterations, warmup = floor(iterations * warmup),
      chains= chains, cores= chains),
    error = function(c) return(list('error', conditionMessage(c)))
  )
  draws <- min(draws, floor(iterations * (1 - warmup)))
  pp <- tryCatch(
    rstanarm::posterior_predict(mod, newdata= test, draws= draws, re.form= NULL, allow.new.levels= TRUE),
    error = function(c) return(list('error', conditionMessage(c)))
  )
  test[, mean_pp := colMeans(pp)]
  return(list(model_fit= mod, oos_results= test))
}
library(data.table)
library(rstanarm)
library(parallel)
SEED <- 1234
set.seed(SEED)
eta <- c(1, -0.2)
gamma <- c(1.8, 0.4)
N <- 200
x <- rnorm(N, 2, 2)
z <- rnorm(N, 0, 2)
mu <- binomial(link = logit)$linkinv(eta[1] + eta[2]*x)
phi <- binomial(link = log)$linkinv(gamma[1] + gamma[2]*z)
y <- rbeta(N, mu * phi, (1 - mu) * phi)
dat <- data.frame(cbind(y, x, z))
setDT(dat)
## split data into train and test
idx <- sample.int(N, size= floor(N * 0.9), replace= FALSE)
train <- dat[idx,]
test <- dat[-idx,]
## setup parallel evaluation
chains <- 1L
iterations <- as.list(rep(c(4000,6000), each= 2))
warmup <- as.list(rep(c(0.6,0.7), 2))
nnodes <- 4L
cl <- parallel::makeCluster(nnodes, type= "FORK")
clusterEvalQ(cl, {
  library(data.table)
  library(rstanarm)
})
clusterExport(cl, varlist= c("iterations", "warmup", "chains", "train", "test", "model_fit"))
test_ex <- parallel::clusterMap(cl, RECYCLE= TRUE, SIMPLIFY= FALSE, .scheduling= "dynamic",
                                fun= model_fit,
                                train= list(train), test= list(test),
                                iterations= iterations, warmup= warmup, chains= chains)
parallel::stopCluster(cl)
## highlight error
## (draws is set explicitly here because no top-level `draws` object is defined)
new_pred <- rstanarm::posterior_predict(
  test_ex[[1]]$model_fit, newdata= test, draws= 1000L, re.form= NULL, allow.new.levels= TRUE)
> Error in formula.default(object, env = baseenv()) : invalid formula
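Since `model_fit()` returns `list('error', <message>)` in place of a model when fitting fails, it may be worth confirming that each worker actually returned a fitted model before calling `posterior_predict` on it. A minimal base-R sketch follows. `is_fit_error` and the toy `results` list are hypothetical names used only for illustration, and the `"stanreg"` class check assumes that is the class of the returned fit.

```r
# Hypothetical helper: TRUE when a worker returned the tryCatch() fallback
# list('error', <message>) instead of a fitted model.
is_fit_error <- function(x) {
  is.list(x) && !inherits(x, "stanreg") &&
    length(x) == 2L && identical(x[[1L]], "error")
}

# Toy stand-ins for the elements of test_ex (no models are fit here).
results <- list(
  list(model_fit = list("error", "something went wrong")),
  list(model_fit = structure(list(), class = "stanreg"))
)

# One logical per worker result: TRUE marks a failed fit.
failed <- vapply(results, function(r) is_fit_error(r$model_fit), logical(1))
print(failed)
```

With the real output, `vapply(test_ex, function(r) is_fit_error(r$model_fit), logical(1))` would flag any worker whose fit failed, which helps separate a fitting failure from a prediction failure.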
RStanARM Version:
rstanarm_2.17.4
R / OS system info:
> sessionInfo()
R version 3.4.1 (2017-06-30)
Platform: x86_64-redhat-linux-gnu (64-bit)
Running under: CentOS Linux 7 (Core)
Matrix products: default
BLAS/LAPACK: /usr/local/lib/libtatlas.so
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=en_US.UTF-8 LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] parallel stats graphics grDevices utils datasets methods base
other attached packages:
[1] rstanarm_2.17.4 Rcpp_0.12.18 data.table_1.11.4
loaded via a namespace (and not attached):
[1] lattice_0.20-35 zoo_1.8-3 gtools_3.8.1 assertthat_0.2.0 digest_0.6.15
[6] mime_0.5 R6_2.2.2 plyr_1.8.4 ggridges_0.5.0 stats4_3.4.1
[11] ggplot2_3.0.0 colourpicker_1.0 pillar_1.3.0 rlang_0.2.1 lazyeval_0.2.1
[16] minqa_1.2.4 rstudioapi_0.7 miniUI_0.1.1 nloptr_1.0.4 Matrix_1.2-14
[21] DT_0.4 shinythemes_1.1.1 splines_3.4.1 shinyjs_1.0 lme4_1.1-17
[26] stringr_1.3.1 loo_2.0.0 htmlwidgets_1.2 igraph_1.1.2 munsell_0.4.3
[31] shiny_1.1.0 compiler_3.4.1 httpuv_1.4.5 rstan_2.17.3 pkgconfig_2.0.1
[36] base64enc_0.1-3 rstantools_1.5.0 htmltools_0.3.6 tidyselect_0.2.4 tibble_1.4.2
[41] gridExtra_2.3 codetools_0.2-15 matrixStats_0.52.2 threejs_0.3.1 crayon_1.3.4
[46] dplyr_0.7.6 later_0.7.3 MASS_7.3-47 grid_3.4.1 nlme_3.1-131
[51] xtable_1.8-2 gtable_0.2.0 magrittr_1.5 StanHeaders_2.17.2 scales_0.5.0
[56] stringi_1.2.4 reshape2_1.4.3 promises_1.0.1 bindrcpp_0.2.2 dygraphs_1.1.1.6
[61] xts_0.11-0 tools_3.4.1 glue_1.3.0 shinystan_2.5.0 markdown_0.8
[66] purrr_0.2.5 crosstalk_1.0.0 survival_2.41-3 rsconnect_0.8.8 yaml_2.2.0
[71] inline_0.3.14 colorspace_1.3-2 bayesplot_1.5.0 bindr_0.1.1
Thanks for reporting this. I just noticed nobody responded. Sorry about that!
I’m not sure if we should be doing something different to avoid this, but I’d like it to work if possible. I’ll try to look into it when I have time (i.e., when I’m stuck on or bored with something more pressing).