
problems when predicting based on models built using R's library(parallel)

Open · alexWhitworth opened this issue 7 years ago · 1 comment

Summary / Description:

I fit a number of models in parallel via parallel::clusterMap. When I try to use posterior_predict outside the cluster, I receive the following error:

Error in formula.default(object, env = baseenv()) : invalid formula

This error can be traced to rstanarm:::.pp_data, specifically to line 62 of that function:

m <- model.frame(Terms, newdata, xlev = object$xlevels)
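
The message itself is raised by base R. When `model.frame()` receives a first argument that is not a formula, `model.frame.default()` passes it to `as.formula()`, which calls `formula(object, env = baseenv())`. That matches the call shown in the error. If the object carries no `formula` or `terms` component, `formula.default()` stops with "invalid formula". The minimal base-R sketch below (no rstanarm) shows which inputs trigger the message. `bad_terms` is a hypothetical stand-in, because what `Terms` actually holds inside `.pp_data` is not known here:

```r
# Base-R sketch (no rstanarm): which first arguments make model.frame()
# fail with "invalid formula"?
toy_new <- data.frame(y = c(0.2, 0.5), x = c(1, 2))

# A genuine terms object is accepted and yields a 2-row model frame.
good_terms <- terms(y ~ x)
mf_good <- model.frame(good_terms, toy_new)
print(dim(mf_good))

# Hypothetical stand-in for a damaged `Terms`: a plain list with no
# `formula` or `terms` component. model.frame.default() hands it to
# as.formula(), which calls formula(object, env = baseenv()); the default
# method then stops with "invalid formula".
bad_terms <- list(a = 1)
msg_bad <- tryCatch(model.frame(bad_terms, toy_new),
                    error = conditionMessage)
print(msg_bad)
```

In the failing case, inspecting `class(Terms)` at that point (for example via `debugonce(rstanarm:::.pp_data)`) would show whether it has lost its `terms`/`formula` class.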

I'm not sure whether this comes down to an issue with base R's stats package or whether it can be avoided within rstanarm. It may also simply be too much of an edge case to warrant the debugging effort.
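
One way to narrow down whether base R's stats package is involved: results returned by `parallel::clusterMap()` are serialized on the worker and unserialized in the master session. The sketch below applies the same round trip (`serialize()`/`unserialize()`) to a fitted model and then rebuilds the model frame in the same way as the `.pp_data` line quoted above. It uses `lm()` purely as a stand-in, so it checks only base R's side under that assumption and says nothing definite about `stanreg` objects:

```r
# Mimic returning a fitted model from a cluster worker: serialize it on
# one side, unserialize it on the other. lm() is only a stand-in here.
set.seed(1234)
toy_train <- data.frame(x = rnorm(20), y = rnorm(20))
toy_new <- data.frame(x = c(0, 1))

fit <- lm(y ~ x, data = toy_train)
fit_rt <- unserialize(serialize(fit, connection = NULL))

# Rebuild the model frame for new data from the round-tripped terms,
# dropping the response (new data has no y), analogous to
# model.frame(Terms, newdata, xlev = object$xlevels).
tt <- delete.response(terms(fit_rt))
mf <- model.frame(tt, toy_new, xlev = fit_rt$xlevels)

# Predictions from the original and the round-tripped fit should agree.
p_orig <- predict(fit, newdata = toy_new)
p_rt <- predict(fit_rt, newdata = toy_new)
print(all.equal(p_orig, p_rt))
```

A natural next check would apply the same round trip to a `stan_betareg` fit made in the master session and then call `posterior_predict()` on it. That variant needs rstanarm and the real model.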

Reproducible Steps:

## Note: this function requires library(data.table)

# @description Function for model fitting and evaluation of models built using 
# \code{\link[rstanarm]{stan_betareg}}. Models are first fit, then out-of-sample predictions
# are calculated, and finally accuracy metrics (RMSE and MAD) are calculated.
# It is intended that this function is used to do model training in parallel.
# @param train A \code{data.table} containing the training data
# @param test A \code{data.table} containing the testing data
# @param iterations An integer scalar specifying the number of iterations for the model
# @param warmup A numeric scalar specifying the fraction of iterations to be used as warmup samples
# @param chains An integer scalar indicating the number of model chains to utilize. Defaults to \code{1L}
# @param draws An integer scalar. Passed to / See \code{\link[rstanarm]{posterior_predict}}.
# @return a \code{list} containing the model fit, out of sample predictions, and accuracy metrics
model_fit <- function(train, test, iterations, warmup, chains= 1L, draws= 1000L) {
  mod <- tryCatch(
    rstanarm::stan_betareg(
      formula= y ~ x | z,
      data= train, link = "logit", link.phi = "log",
      prior= normal(), prior_intercept= normal(), prior_phi= exponential(),
      iter= iterations, warmup = floor(iterations * warmup),
      chains= chains, cores= chains),
    error = function(c) return(list('error', conditionMessage(c)))
    )
  
  draws <- min(draws, floor(iterations * (1 - warmup)))
  pp <- tryCatch(
    rstanarm::posterior_predict(mod, newdata= test, draws= draws, re.form= NULL, allow.new.levels= TRUE),
    error = function(c) return(list('error', conditionMessage(c)))
  )
  test[, mean_pp := colMeans(pp)]
  
  return(list(model_fit= mod, oos_results= test))
}

library(data.table)
library(rstanarm)
library(parallel)

SEED <- 1234
set.seed(SEED)
eta <- c(1, -0.2)
gamma <- c(1.8, 0.4)
N <- 200
x <- rnorm(N, 2, 2)
z <- rnorm(N, 0, 2)
mu <- binomial(link = logit)$linkinv(eta[1] + eta[2]*x)
phi <- binomial(link = log)$linkinv(gamma[1] + gamma[2]*z)
y <- rbeta(N, mu * phi, (1 - mu) * phi)
dat <- data.frame(cbind(y, x, z))
setDT(dat)

## split data into train and test
idx <- sample.int(N, size= floor(N * 0.9), replace= FALSE)
train <- dat[idx,]
test <- dat[-idx,]


## setup parallel evaluation
chains <- 1L
iterations <- as.list(rep(c(4000,6000), each= 2))
warmup <- as.list(rep(c(0.6,0.7), 2))

nnodes <- 4L
cl <- parallel::makeCluster(nnodes, type= "FORK")

clusterEvalQ(cl, {
  library(data.table)
  library(rstanarm)
})
clusterExport(cl, varlist= c("iterations", "warmup", "chains", "train", "test", "model_fit"))

test_ex <- parallel::clusterMap(cl, RECYCLE= TRUE, SIMPLIFY= FALSE, .scheduling= "dynamic",
                              fun= model_fit, 
                              train= list(train), test= list(test),
                              iterations= iterations, warmup= warmup, chains= chains)
parallel::stopCluster(cl)

## highlight error
draws <- 1000L  # `draws` is otherwise only defined inside model_fit()
new_pred <- rstanarm::posterior_predict(
  test_ex[[1]]$model_fit, newdata= test, draws= draws, re.form= NULL, allow.new.levels= TRUE)

> Error in formula.default(object, env = baseenv()) : invalid formula

RStanARM Version:

rstanarm_2.17.4

R / OS system info:

> sessionInfo()
R version 3.4.1 (2017-06-30)
Platform: x86_64-redhat-linux-gnu (64-bit)
Running under: CentOS Linux 7 (Core)

Matrix products: default
BLAS/LAPACK: /usr/local/lib/libtatlas.so

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
 [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8    LC_PAPER=en_US.UTF-8       LC_NAME=C                 
 [9] LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] parallel  stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] rstanarm_2.17.4   Rcpp_0.12.18      data.table_1.11.4

loaded via a namespace (and not attached):
 [1] lattice_0.20-35    zoo_1.8-3          gtools_3.8.1       assertthat_0.2.0   digest_0.6.15     
 [6] mime_0.5           R6_2.2.2           plyr_1.8.4         ggridges_0.5.0     stats4_3.4.1      
[11] ggplot2_3.0.0      colourpicker_1.0   pillar_1.3.0       rlang_0.2.1        lazyeval_0.2.1    
[16] minqa_1.2.4        rstudioapi_0.7     miniUI_0.1.1       nloptr_1.0.4       Matrix_1.2-14     
[21] DT_0.4             shinythemes_1.1.1  splines_3.4.1      shinyjs_1.0        lme4_1.1-17       
[26] stringr_1.3.1      loo_2.0.0          htmlwidgets_1.2    igraph_1.1.2       munsell_0.4.3     
[31] shiny_1.1.0        compiler_3.4.1     httpuv_1.4.5       rstan_2.17.3       pkgconfig_2.0.1   
[36] base64enc_0.1-3    rstantools_1.5.0   htmltools_0.3.6    tidyselect_0.2.4   tibble_1.4.2      
[41] gridExtra_2.3      codetools_0.2-15   matrixStats_0.52.2 threejs_0.3.1      crayon_1.3.4      
[46] dplyr_0.7.6        later_0.7.3        MASS_7.3-47        grid_3.4.1         nlme_3.1-131      
[51] xtable_1.8-2       gtable_0.2.0       magrittr_1.5       StanHeaders_2.17.2 scales_0.5.0      
[56] stringi_1.2.4      reshape2_1.4.3     promises_1.0.1     bindrcpp_0.2.2     dygraphs_1.1.1.6  
[61] xts_0.11-0         tools_3.4.1        glue_1.3.0         shinystan_2.5.0    markdown_0.8      
[66] purrr_0.2.5        crosstalk_1.0.0    survival_2.41-3    rsconnect_0.8.8    yaml_2.2.0        
[71] inline_0.3.14      colorspace_1.3-2   bayesplot_1.5.0    bindr_0.1.1    

alexWhitworth, Aug 30 '18 17:08

Thanks for reporting this. I just noticed nobody responded. Sorry about that!

I’m not sure whether we should be doing something different to avoid this, but I’d like it to work if possible. I’ll try to look into it when I have time (i.e., when I’m stuck on or bored with something more pressing).

jgabry, May 09 '19 19:05