Identify vectorization problems of exposed functions in non-linear models with nested covariates
This prevents predictions from being made easily:
devtools::load_all("~/rwork/brms")
#> ℹ Loading brms
#> Loading required package: Rcpp
#>
#> Loading 'brms' package (version 2.19.6). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
set.seed(2134)
N <- 100
dat <- data.frame(y=rnorm(N))
dat$X <- matrix(rnorm(N*2), N, 2)
nlfun_stan <- "
real nlfun(real a, real b, real c, row_vector X) {
return a + b * X[1] + c * X[2];
}
"
nlstanvar <- stanvar(scode = nlfun_stan, block = "functions")
# version for R post-processing (should be created correctly by
# expose_functions below)
#nlfun <- function(a, b, c, X) {
# a + b * X[, , 1] + c * X[, , 2]
#}
# fit the model
bform <- bf(y~nlfun(a, b, c, X), a~1, b~1, c~1, nl = TRUE)
fit <- brm(bform, dat, stanvars = nlstanvar, refresh=0)
#> Compiling Stan program...
#> Start sampling
summary(fit)
#> Family: gaussian
#> Links: mu = identity; sigma = identity
#> Formula: y ~ nlfun(a, b, c, X)
#> a ~ 1
#> b ~ 1
#> c ~ 1
#> Data: dat (Number of observations: 100)
#> Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#> total post-warmup draws = 4000
#>
#> Population-Level Effects:
#> Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> a_Intercept -0.02 0.10 -0.22 0.19 1.00 4600 2598
#> b_Intercept -0.08 0.10 -0.28 0.11 1.00 4108 2977
#> c_Intercept 0.03 0.10 -0.17 0.24 1.00 4203 3152
#>
#> Family Specific Parameters:
#> Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> sigma 1.02 0.07 0.89 1.18 1.00 3853 3113
#>
#> Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).
## expose the non-linear function for post-processing
expose_functions(fit, vectorize=TRUE)
pp <- posterior_predict(fit)
#> Error: Error in (function (a, b, c, X, pstream__ = <pointer: 0x108a90af0>) :
#> Exception: []: accessing element out of range. index 2 out of range; expecting index to be between 1 and 1; index position = 1X (in 'unknown file name' at line 5)
Created on 2023-05-26 with reprex v2.0.2
How do you propose to fix this? brms merely vectorizes the exposed Stan function via base R's Vectorize().
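To illustrate why this fails, here is a minimal sketch in plain R, assuming the exposed Stan function behaves like a scalar function of `a`, `b`, `c` and a single `row_vector X`. `nlfun_scalar` is a hypothetical stand-in (no Stan compilation needed). Vectorize() passes every argument to mapply(), which iterates over a matrix element by element, so each call receives a length-1 `X`. That is consistent with the "expecting index to be between 1 and 1" message above.

```r
# Hypothetical plain-R stand-in for the exposed Stan function
# nlfun(real a, real b, real c, row_vector X); it mimics Stan's
# out-of-range error when X has fewer than two elements.
nlfun_scalar <- function(a, b, c, X) {
  if (length(X) < 2) {
    stop("index 2 out of range; expecting index to be between 1 and ", length(X))
  }
  a + b * X[1] + c * X[2]
}

# By default, Vectorize() vectorizes over all arguments via mapply()
nlfun_vec <- Vectorize(nlfun_scalar)

X <- matrix(rnorm(6), nrow = 3, ncol = 2)
# mapply() walks through the matrix X element by element (6 calls),
# so every call receives a length-1 X and the function fails.
res <- tryCatch(nlfun_vec(a = 0, b = 1, c = 1, X = X),
                error = function(e) conditionMessage(e))
print(res)
```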
For the moment, an error message saying that this case requires a custom R function doing the by-draw work would be helpful. This is probably a special case that does not affect many users.
Other than that, one would need to come up with a clever extension of Vectorize()... For now I am simply looping as a workaround. If I come up with a straightforward generic solution, I am happy to post it here.
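One possible direction for a row-aware extension of Vectorize() is sketched below. `vectorize_rows` is a hypothetical helper, not part of brms or base R. It splits any matrix argument into a list of its rows before calling mapply(), so the scalar function receives one row vector per observation. It assumes `X` arrives as an observations-by-columns matrix, as in the reprex data. How brms actually shapes the arguments it passes during post-processing (e.g., per draw) is not checked here.

```r
# Hypothetical helper (not part of brms): like Vectorize(), but matrix
# arguments are split into rows, so FUN sees one row vector per call
# instead of one matrix element per call.
vectorize_rows <- function(FUN) {
  function(...) {
    args <- list(...)
    args <- lapply(args, function(x) {
      if (is.matrix(x)) split(x, row(x)) else x
    })
    do.call(mapply, c(list(FUN = FUN, SIMPLIFY = TRUE, USE.NAMES = FALSE), args))
  }
}

# Hypothetical plain-R stand-in for the scalar Stan function
nlfun_scalar <- function(a, b, c, X) {
  stopifnot(length(X) >= 2)
  a + b * X[1] + c * X[2]
}

nlfun_rows <- vectorize_rows(nlfun_scalar)

set.seed(1)
X <- matrix(rnorm(10), nrow = 5, ncol = 2)
av <- rnorm(5); bv <- rnorm(5); cv <- rnorm(5)
out <- nlfun_rows(a = av, b = bv, c = cv, X = X)

# Explicit loop over observations (the looping workaround) for comparison
ref <- numeric(nrow(X))
for (i in seq_len(nrow(X))) ref[i] <- nlfun_scalar(av[i], bv[i], cv[i], X[i, ])
all.equal(out, ref)
```

The final all.equal() compares the row-wise result against the explicit loop and should return TRUE. The trade-off of this design is that every matrix argument is treated as row-wise data, which would be wrong for a function that expects a full matrix in each call.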
I am not sure brms can even detect why this custom function fails, so I am not sure there is anything I can do about it in terms of an error message. I am leaving this issue open as a feature request for the future.