stan backend
mcp 2.0 will support stan in addition to JAGS. That is still far in the future, but this issue collects the work points.
- [ ] Obviously, generate a stan model, pass data, and sample it.
- [ ] Support bridgesampling-based Bayes Factors.
- [ ] Can the JAGS and stan functions be dropped as dependencies, only to be installed upon first use (e.g., upon calling mcp(model, data, backend = "stan"))? Otherwise, the dependencies would be quite heavy for non-JAGS and non-stan users. A sketch of such a check follows this list.
- [ ] Check if stan samples more effectively using a continuous step function, e.g., as in this post.
- [ ] Option or default to no prior for non-intercept and non-changepoint parameters? Cf. #122.
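A minimal sketch of what such a just-in-time dependency check could look like, assuming rjags and rstan as the backend packages; the helper name assert_backend and the message text are hypothetical, not current mcp API:

# Hypothetical helper: verify that the requested backend is installed,
# erroring with install instructions only when that backend is first used.
assert_backend <- function(backend = c("jags", "stan")) {
  backend <- match.arg(backend)
  pkg <- switch(backend, jags = "rjags", stan = "rstan")
  if (!requireNamespace(pkg, quietly = TRUE)) {
    stop("backend = '", backend, "' requires the '", pkg, "' package. ",
         "Install it with install.packages('", pkg, "').", call. = FALSE)
  }
  invisible(backend)
}

# Inside mcp(), before sampling:
# assert_backend(backend)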
Awesome. Excited to see this on the roadmap. I'd love to contribute to this while still learning Bayesian modeling. Do you have any suggestions or a contributor guide? I would also be interested in implementing a Python version with the PyMC3 backend.
Thanks, @jpzhangvincent, that would be great! I think getting it to work is simply a matter of (a) re-writing a few JAGS models as stan models and learning whether they work well, and (b) writing an R function that generates these from mcp's internal representation of the model. I could really use some input on (a) here, as my stan skills are limited.
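For (b), here is a toy illustration of the kind of string-building such a generator could do; the function name, the parameter list, and the priors are all made up for illustration and say nothing about mcp's actual internals:

# Toy sketch: turn a named list of parameters into stan prior statements.
make_stan_priors <- function(pars) {
  lines <- vapply(
    names(pars),
    function(name) paste0("  ", name, " ~ ", pars[[name]], ";"),
    character(1)
  )
  paste(lines, collapse = "\n")
}

cat(make_stan_priors(list(
  Intercept_1 = "student_t(3, 0, 10)",
  x_2         = "student_t(3, 0, 5)"
)))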
mcp is undergoing heavy internal restructuring and a few breaking changes, most of which are tracked in issue #90. I think it makes sense to wait until after that release, when things hopefully settle down. But I think the JAGS part is finished now. mcp 0.4 takes formulas like this:
model = list(
  y ~ 1 + x:group,
  ~ 0 + x,
  ~ 1 + sigma(1 + group)
)
which, for data like
> head(df)
  x group         y          z
1 1     A -1.431554 -5.9042791
2 2     B 12.819796  1.6075971
3 3     C 17.218474  4.8689988
4 4     D  9.243459 -2.1581639
5 5     A  9.609940 10.1076712
6 6     B  9.544842  0.2298296
generates JAGS code like this:
model {
  # mcp helper values
  cp_0 = MINX
  cp_3 = MAXX

  # Priors for population-level effects
  cp_1 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_0, MAXX)
  cp_2 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_1, MAXX)
  Intercept_1 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3)
  xgroupA_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3)
  xgroupB_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3)
  xgroupC_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3)
  xgroupD_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3)
  sigma_1 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  x_2 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3)
  Intercept_3 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3)
  sigma_3 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  sigma_groupB_3 ~ dt(0, 1/(SDLINKY)^2, 3)
  sigma_groupC_3 ~ dt(0, 1/(SDLINKY)^2, 3)
  sigma_groupD_3 ~ dt(0, 1/(SDLINKY)^2, 3)

  # Model and likelihood
  for (i_ in 1:length(x)) {
    # par_x local to each segment
    x_local_1_[i_] = min(x[i_], cp_1)
    x_local_2_[i_] = min(x[i_], cp_2) - cp_1
    x_local_3_[i_] = min(x[i_], cp_3) - cp_2

    # Formula for mu
    mu_[i_] =
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(1)], c(Intercept_1)) * 1 +
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(2, 3, 4, 5)], c(xgroupA_1, xgroupB_1, xgroupC_1, xgroupD_1)) * x_local_1_[i_] +

      # Segment 2: y ~ 1 ~ 0 + x
      (x[i_] >= cp_1) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(7)], c(x_2)) * x_local_2_[i_] +

      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(8)], c(Intercept_3)) * 1

    # Formula for sigma
    sigma_[i_] = max(10^-9, sigma_tmp[i_])  # Count negative sigma as just-above-zero sigma
    sigma_tmp[i_] =
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(6)], c(sigma_1)) * 1 +

      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(9, 10, 11, 12)], c(sigma_3, sigma_groupB_3, sigma_groupC_3, sigma_groupD_3)) * 1

    # Likelihood and log-density for family = gaussian()
    y[i_] ~ dnorm((mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
    loglik_[i_] = logdensity.norm(y[i_], (mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
  }
}
Here, rhs_data_ is the model.matrix, but with x factored out of all terms. x is then "factored in" in JAGS, as you can see. inprod is simply equivalent to %*% in base R.
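To make that concrete, a small base-R illustration; the data frame mirrors the example above, and the coefficient values are made up:

# rhs_data_-style design matrix: group dummies with x factored out of the term.
df <- data.frame(x = 1:6, group = c("A", "B", "C", "D", "A", "B"))
X <- model.matrix(~ 0 + group, data = df)

# Made-up draws for the four segment-1 slopes.
betas <- c(1.2, -0.5, 0.8, 0.0)

# JAGS: inprod(rhs_data_[i_, c(2, 3, 4, 5)], c(xgroupA_1, ...)) * x_local_1_[i_]
# Base R equivalent for row i (x is "factored in" by the final multiplication):
i <- 1
sum(X[i, ] * betas) * df$x[i]   # same as (X %*% betas)[i] * df$x[i]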
Some of the work points for generating an equivalent stan model are:
- I think some of the priors can be dropped in stan (JAGS requires priors for everything).
- I think stan allows for vectorizing, so we can get rid of the for-loop (see the sketch after this list).
- I have to learn more stan to see if some of it can be moved to a "data" chunk, etc.
- There are many equivalent ways to represent the formula part, but JAGS samples considerably faster for this particular one. I'd like to see if stan is more robust, so that we needn't have multiple lines of code for each segment.
- In general, how can this be made to run most efficiently in stan? Can we use some of the new primitives, can we make a model that runs on GPU, etc.?

Would love any tips, example stan models, or thoughts!
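As a starting point for discussion, here is a hand-written sketch (not mcp-generated) of a much simpler two-segment gaussian model in stan, run via rstan; all names and priors are illustrative only:

library(rstan)

stan_code <- "
data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  // No explicit prior on cp_1: implicitly uniform over [min(x), max(x)].
  real<lower=min(x), upper=max(x)> cp_1;  // change point
  real Intercept_1;                       // segment 1 intercept
  real x_2;                               // segment 2 slope
  real<lower=0> sigma_1;
}
transformed parameters {
  vector[N] mu;
  // Joined slope: x_2 only applies after the change point. Recent stan
  // versions may allow building mu without this loop as well.
  for (i in 1:N)
    mu[i] = Intercept_1 + x_2 * fmax(x[i] - cp_1, 0);
}
model {
  Intercept_1 ~ student_t(3, 0, 10);  // illustrative stand-in for JAGS dt()
  x_2 ~ student_t(3, 0, 10);
  sigma_1 ~ student_t(3, 0, 10);      // implicitly truncated by <lower=0>
  y ~ normal(mu, sigma_1);            // vectorized: no likelihood for-loop
}
"

# fit <- stan(model_code = stan_code,
#             data = list(N = nrow(df), x = df$x, y = df$y))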
As far as dependencies go, you can:
- Have JAGS/Stan as suggested dependencies (Suggests rather than Imports).
- On startup:
  - If neither is installed, give the user a message.
  - If only one is installed, set some options() to use that one (see the sketch below).
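A minimal .onAttach() sketch of that startup logic, assuming rjags and rstan as the backend packages; the option name mcp.backend is hypothetical:

# Hypothetical startup hook: detect which backends are installed and either
# warn the user or preselect the only available one via an option.
.onAttach <- function(libname, pkgname) {
  has_jags <- requireNamespace("rjags", quietly = TRUE)
  has_stan <- requireNamespace("rstan", quietly = TRUE)
  if (!has_jags && !has_stan) {
    packageStartupMessage(
      "mcp: neither 'rjags' nor 'rstan' is installed. ",
      "Install one of them to sample models."
    )
  } else if (xor(has_jags, has_stan)) {
    # Only one backend present: make it the default. 'mcp.backend' is a
    # hypothetical option name, not part of the current mcp API.
    options(mcp.backend = if (has_jags) "jags" else "stan")
  }
}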