
stan backend

lindeloev opened this issue on Jan 05 '21

mcp 2.0 will support stan in addition to JAGS. That is far out in the future, but this issue collects work points.

  • [ ] Obviously, generate a stan model, pass data, and sample it.
  • [ ] Support bridgesampling-based Bayes Factors
  • [ ] Can the jags and stan packages be dropped as hard dependencies and only installed upon first use (i.e., at the first call to mcp(model, data, backend = "stan"))? Otherwise, the dependencies would be quite heavy for non-JAGS and non-stan users. See the sketch after this list.
  • [ ] Check if stan samples more effectively using a continuous step function, e.g., as in this post.
  • [ ] Option or default to no prior for non-intercept and non-changepoint parameters? Cf. #122.
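
One possible pattern for the dependency point above (a sketch, not a committed API; get_backend() is a hypothetical helper): keep the backend packages in Suggests and only require them when that backend is first used.

get_backend <- function(backend = c("jags", "stan")) {
  backend <- match.arg(backend)
  # Map the backend choice to the R package that implements it
  pkg <- switch(backend, jags = "rjags", stan = "rstan")
  # Fail informatively at first use instead of forcing a heavy install
  if (!requireNamespace(pkg, quietly = TRUE)) {
    stop("Please install the '", pkg, "' package to use backend = '",
         backend, "'.", call. = FALSE)
  }
  backend
}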

lindeloev · Jan 05 '21

Awesome. Excited to see this on the roadmap. I'd love to contribute to this while still learning Bayesian modeling. Do you have any suggestions or a contributor guide? I would also be interested in implementing a Python version with a PyMC3 backend.

jpzhangvincent · Jan 12 '21

Thanks, @jpzhangvincent, that would be great! I think getting it to work is simply a matter of (a) re-writing a few JAGS models as stan models and learning whether they work well, and (b) writing an R function that generates these from mcp's internal representation of the model. I could really use some input on (a) here, as my stan skills are limited.

mcp is undergoing heavy internal restructuring with a few breaking changes, most of which are tracked in issue #90. I think it makes sense to wait until after that release, when things hopefully settle down. But I think the JAGS part is finished now. mcp 0.4 takes formulas like this:

model = list(
  y ~ 1 + x:group,
  ~ 0 + x,
  ~ 1 + sigma(1 + group)
)

which for data like

> head(df)
  x group         y          z
1 1     A -1.431554 -5.9042791
2 2     B 12.819796  1.6075971
3 3     C 17.218474  4.8689988
4 4     D  9.243459 -2.1581639
5 5     A  9.609940 10.1076712
6 6     B  9.544842  0.2298296

generates JAGS code like this:

model {
  # mcp helper values
  cp_0 = MINX
  cp_3 = MAXX

  # Priors for population-level effects
  cp_1 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_0, MAXX)
  cp_2 ~ dt(MINX, 1/((MAXX-MINX)/N_CP)^2, N_CP-1) T(cp_1, MAXX)
  Intercept_1 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3) 
  xgroupA_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupB_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupC_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  xgroupD_1 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  sigma_1 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  x_2 ~ dt(0, 1/(N_CP*SDLINKY/(MAXX-MINX))^2, 3) 
  Intercept_3 ~ dt(MEANLINKY, 1/(SDLINKY)^2, 3) 
  sigma_3 ~ dt(0, 1/(SDLINKY)^2, 3) T(0, )
  sigma_groupB_3 ~ dt(0, 1/(SDLINKY)^2, 3) 
  sigma_groupC_3 ~ dt(0, 1/(SDLINKY)^2, 3) 
  sigma_groupD_3 ~ dt(0, 1/(SDLINKY)^2, 3) 

  # Model and likelihood
  for (i_ in 1:length(x)) {
    # par_x local to each segment
    x_local_1_[i_] = min(x[i_], cp_1)
    x_local_2_[i_] = min(x[i_], cp_2) - cp_1
    x_local_3_[i_] = min(x[i_], cp_3) - cp_2
    
    # Formula for mu
    mu_[i_] =
    
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(1)], c(Intercept_1)) * 1 + 
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(2, 3, 4, 5)], c(xgroupA_1, xgroupB_1, xgroupC_1, xgroupD_1)) * x_local_1_[i_] + 
    
      # Segment 2: y ~ 1 ~ 0 + x
      (x[i_] >= cp_1) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(7)], c(x_2)) * x_local_2_[i_] + 
    
      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(8)], c(Intercept_3)) * 1
    
    # Formula for sigma
    sigma_[i_] = max(10^-9, sigma_tmp[i_])  # Count negative sigma as just-above-zero sigma
    sigma_tmp[i_] =  
      # Segment 1: y ~ 1 + x:group
      (x[i_] >= cp_0) * (x[i_] < cp_2) * inprod(rhs_data_[i_, c(6)], c(sigma_1)) * 1 + 
    
      # Segment 3: y ~ 1 ~ 1 + sigma(1 + group)
      (x[i_] >= cp_2) * inprod(rhs_data_[i_, c(9, 10, 11, 12)], c(sigma_3, sigma_groupB_3, sigma_groupC_3, sigma_groupD_3)) * 1

    # Likelihood and log-density for family = gaussian()
    y[i_] ~ dnorm((mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
    loglik_[i_] = logdensity.norm(y[i_], (mu_[i_]), 1 / sigma_[i_]^2)  # SD as precision
  }
}

Here, rhs_data_ is the model.matrix() output, but with x factored out of all terms; x is then "factored back in" in JAGS, as you can see. inprod() is simply equivalent to %*% in base R.
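
For intuition, here is a small illustrative R sketch of that idea (my approximation of the bookkeeping, not mcp's actual internals):

# A minimal stand-in for the df shown above
df <- data.frame(x = 1:6,
                 group = factor(c("A", "B", "C", "D", "A", "B")))

# Design matrix for segment 1: an intercept plus the x:group columns
mm <- model.matrix(~ 1 + x:group, data = df)
colnames(mm)  # "(Intercept)" "x:groupA" "x:groupB" "x:groupC" "x:groupD"

# "Factor x out": divide x out of the interaction columns, leaving group
# indicators; JAGS then multiplies the segment-local x back in
rhs <- mm
rhs[, -1] <- rhs[, -1] / df$x

# inprod(a, b) in JAGS is just the dot product: sum(a * b), or a %*% b
sum(rhs[1, 2:5] * c(1.1, 2.2, 3.3, 4.4))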

Some of the work points for generating an equivalent stan model (see the sketch after this list):

  1. I think some of the priors can be dropped in stan (JAGS requires priors for everything).
  2. I think stan allows for vectorizing, so we can get rid of the for-loop.
  3. I have to learn more stan to see if some of it can be moved to a "data" chunk, etc.
  4. There are many equivalent ways to represent the formula part, but JAGS samples considerably faster for this particular one. I'd like to see whether stan is more robust here, so that we don't need multiple lines of code for each segment.
  5. In general, how can this be made to run as efficiently as possible in stan? Can we use some of the new primitives? Can we make a model that runs on a GPU?
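
To make this concrete, here is my untested first guess at a hand-written stan counterpart for a much simpler two-segment model (y ~ 1 + x, then ~ 0 + x). It is just one of the many possible parameterizations from point 4, and it assumes rstan plus a stan version (2.25+) in which fmin()/fmax() are vectorized:

stan_code <- "
data {
  int<lower=1> N;
  vector[N] x;
  vector[N] y;
  real MINX;
  real MAXX;
}
parameters {
  real<lower=MINX, upper=MAXX> cp_1;  // bounds replace JAGS' T(cp_0, MAXX)
  real Intercept_1;
  real x_1;
  real x_2;
  real<lower=0> sigma_1;
}
model {
  // Segment-local x, vectorized instead of the JAGS for-loop (point 2)
  vector[N] mu = Intercept_1
                 + x_1 * fmin(x, cp_1)
                 + x_2 * fmax(x - cp_1, 0);

  // Parameters without explicit priors get improper flat priors (point 1)
  cp_1 ~ student_t(3, (MINX + MAXX) / 2, (MAXX - MINX) / 2);

  y ~ normal(mu, sigma_1);  // stan takes the SD directly, not a precision
}
"
# fit <- rstan::stan(model_code = stan_code,
#                    data = list(N = nrow(df), x = df$x, y = df$y,
#                                MINX = min(df$x), MAXX = max(df$x)))

One caveat for the bridgesampling point above: bridge sampling needs proper priors on all parameters, so the flat defaults would not do there, and truncated priors must then be written with explicit normalization (stan's T[MINX, MAXX] syntax).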

Would love any tips, example stan models, or thoughts!

lindeloev · Jan 12 '21

As far as dependencies go, you could:

  1. Have the JAGS/Stan packages in Suggests rather than Imports
  2. On startup
    • If neither is installed, give the user a message.
    • If only one is installed, set some options() to use that one (see the sketch below).
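
Roughly like this (a sketch; the option name mcp.backend is just a placeholder):

.onAttach <- function(libname, pkgname) {
  has_jags <- requireNamespace("rjags", quietly = TRUE)
  has_stan <- requireNamespace("rstan", quietly = TRUE)
  if (!has_jags && !has_stan) {
    packageStartupMessage("Neither 'rjags' nor 'rstan' is installed. ",
                          "Install one of them to fit mcp models.")
  } else if (xor(has_jags, has_stan)) {
    # Only one backend available, so make it the default
    options(mcp.backend = if (has_jags) "jags" else "stan")
  }
}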

mattansb · Mar 17 '21