
Identify vectorization problems of exposed functions from non-linear model with nested covariates

Open wds15 opened this issue 2 years ago • 3 comments

This prevents predictions from being made easily:

devtools::load_all("~/rwork/brms")
#> ℹ Loading brms
#> Loading required package: Rcpp
#> 
#> Loading 'brms' package (version 2.19.6). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
set.seed(2134)
N <- 100
dat <- data.frame(y=rnorm(N))
dat$X <- matrix(rnorm(N*2), N, 2)

nlfun_stan <- "
  real nlfun(real a, real b, real c, row_vector X) {
     return a + b * X[1] + c * X[2];
  }
"
nlstanvar <- stanvar(scode = nlfun_stan, block = "functions")

# version for R post-processing (should be created correctly by
# expose_functions below)
#nlfun <- function(a, b, c, X) {
#  a + b * X[, , 1] + c * X[, , 2]
#}

# fit the model
bform <- bf(y~nlfun(a, b, c, X), a~1, b~1, c~1, nl = TRUE)
fit <- brm(bform, dat, stanvars = nlstanvar, refresh=0)
#> Compiling Stan program...
#> Start sampling
summary(fit)
#>  Family: gaussian 
#>   Links: mu = identity; sigma = identity 
#> Formula: y ~ nlfun(a, b, c, X) 
#>          a ~ 1
#>          b ~ 1
#>          c ~ 1
#>    Data: dat (Number of observations: 100) 
#>   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#>          total post-warmup draws = 4000
#> 
#> Population-Level Effects: 
#>             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> a_Intercept    -0.02      0.10    -0.22     0.19 1.00     4600     2598
#> b_Intercept    -0.08      0.10    -0.28     0.11 1.00     4108     2977
#> c_Intercept     0.03      0.10    -0.17     0.24 1.00     4203     3152
#> 
#> Family Specific Parameters: 
#>       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> sigma     1.02      0.07     0.89     1.18 1.00     3853     3113
#> 
#> Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).

## export non-linear function for prediction
expose_functions(fit, vectorize=TRUE)

pp <- posterior_predict(fit)
#> Error: Error in (function (a, b, c, X, pstream__ = <pointer: 0x108a90af0>)  : 
#>   Exception: []: accessing element out of range. index 2 out of range; expecting index to be between 1 and 1; index position = 1X  (in 'unknown file name' at line 5)

Created on 2023-05-26 with reprex v2.0.2

wds15 • May 26 '23 14:05

How do you propose to fix this? brms merely vectorizes the Stan function via Vectorize.
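To illustrate why plain `Vectorize` runs into the error above, here is a minimal sketch in base R. `nlfun_scalar` is a hypothetical stand-in for the exposed Stan function, not what `expose_functions` generates, and its explicit length check only mimics Stan's bounds checking (plain R would return `NA` for `X[2]`). `Vectorize` wraps the function in `mapply`, which walks over every argument element by element, so a matrix-valued `X` is split into single numbers instead of rows.

```r
# Hypothetical stand-in for the exposed Stan function: scalar a, b, c and a
# length-2 covariate vector X. The length check mimics Stan's bounds checking.
nlfun_scalar <- function(a, b, c, X) {
  if (length(X) < 2) stop("index 2 out of range; X has length ", length(X))
  a + b * X[1] + c * X[2]
}

# Vectorize() delegates to mapply(), which iterates over the individual
# elements of every argument, including the entries of a matrix.
nlfun_vec <- Vectorize(nlfun_scalar)

# Two observations with two covariates each.
X <- matrix(c(1, 2, 3, 4), nrow = 2, ncol = 2)

# Each call receives a single number for X, so the second element is missing.
res <- tryCatch(
  nlfun_vec(a = c(0, 0), b = c(1, 1), c = c(1, 1), X = X),
  error = function(e) conditionMessage(e)
)
print(res)
```

The failure arises in the wrapped function, after `mapply` has already split `X`.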

paul-buerkner • May 27 '23 17:05

For the moment, an error message saying that this case requires a custom function doing the by-draw work would be helpful. This is a special case that probably does not affect many users.

Beyond that, one would need to come up with a clever extension of Vectorize. For now I am simply looping as a workaround. If I come up with a straightforward generic solution, I am happy to drop it here.
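A looping workaround of this kind might look roughly like the sketch below. It rests on assumptions: `nlfun_scalar` is a hypothetical stand-in for the exposed Stan function, `nlfun_loop` is a made-up name, and the argument shapes (draws x observations matrices for `a`, `b`, `c`, and a draws x observations x 2 array for `X`) are inferred from the indexing `X[, , 1]` in the commented-out R function of the reprex. Whether brms passes exactly these shapes in every post-processing path is not confirmed here.

```r
# Hypothetical scalar stand-in for the exposed Stan function (not brms output).
nlfun_scalar <- function(a, b, c, X) {
  a + b * X[1] + c * X[2]
}

# Assumed shapes: a, b, c are draws x observations matrices; X is a
# draws x observations x 2 array.
nlfun_loop <- function(a, b, c, X) {
  out <- matrix(NA_real_, nrow = nrow(a), ncol = ncol(a))
  for (s in seq_len(nrow(a))) {
    for (n in seq_len(ncol(a))) {
      # Hand over the full covariate vector of observation n in draw s.
      out[s, n] <- nlfun_scalar(a[s, n], b[s, n], c[s, n], X[s, n, ])
    }
  }
  out
}

# Small simulated check against the directly vectorized R expression.
set.seed(1)
S <- 3  # number of draws
N <- 4  # number of observations
a_draws <- matrix(rnorm(S * N), S, N)
b_draws <- matrix(rnorm(S * N), S, N)
c_draws <- matrix(rnorm(S * N), S, N)
X_arr <- array(rnorm(S * N * 2), dim = c(S, N, 2))

looped <- nlfun_loop(a_draws, b_draws, c_draws, X_arr)
direct <- a_draws + b_draws * X_arr[, , 1] + c_draws * X_arr[, , 2]
print(all.equal(looped, direct))
```

The double loop is slow for many draws, but it keeps each covariate row intact, which element-wise vectorization does not.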

wds15 • May 30 '23 07:05

Not sure brms can even identify why this custom function fails, so I am not sure there is anything I can do about it in terms of an error message. Leaving this issue open as a feature for the future.

paul-buerkner • May 30 '23 21:05