
Thinning dbartsSampler in a custom Gibbs Sampler

jacobenglert opened this issue

I'm curious whether there is a good way to thin a `dbartsSampler` that exists as part of a larger custom MCMC. The existing `n.thin` argument to `dbartsControl` applies the thinning internally: when the sampler is run, it only returns `n.samples` draws, and the intermediate iterations are discarded in the background. In practice, it would be helpful to update the other parameters of the model between every individual BART iteration and then thin the whole chain jointly.
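
To make the request concrete, here is a minimal pure-R sketch of joint thinning on a toy two-block Gibbs sampler. The model (a standard bivariate normal with correlation `rho`) and the function name `gibbs_joint_thin` are hypothetical and chosen only for illustration; nothing here is part of dbarts. Every block is updated on every iteration, and thinning only happens when the joint state is stored:

```r
# Toy two-block Gibbs sampler for a standard bivariate normal with correlation rho.
# Hypothetical illustration: block "a" stands in for the BART update and block "b"
# for the other model parameters (e.g. random intercepts).
gibbs_joint_thin <- function(n.iter, n.thin, rho = 0.9, seed = 1) {
  set.seed(seed)
  n.keep <- n.iter %/% n.thin
  draws <- matrix(NA_real_, nrow = n.keep, ncol = 2,
                  dimnames = list(NULL, c("a", "b")))
  n.updates <- c(a = 0L, b = 0L)   # how often each block was actually updated
  a <- 0
  b <- 0
  cond.sd <- sqrt(1 - rho^2)
  for (k in seq_len(n.iter)) {
    a <- rnorm(1, rho * b, cond.sd)   # update block a given b
    n.updates["a"] <- n.updates["a"] + 1L
    b <- rnorm(1, rho * a, cond.sd)   # update block b given a
    n.updates["b"] <- n.updates["b"] + 1L
    # Joint thinning: store both blocks from the same iteration
    if (k %% n.thin == 0) draws[k %/% n.thin, ] <- c(a, b)
  }
  list(draws = draws, n.updates = n.updates)
}

res <- gibbs_joint_thin(n.iter = 1000, n.thin = 5)
dim(res$draws)   # 200 2
res$n.updates    # both blocks updated 1000 times
```

Both blocks receive all 1000 updates while only 200 joint draws are kept; in the dbarts setting, block a would correspond to a single `sampler$run(0L, 1L)` call with the sampler's own `n.thin` left at 1.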

As an example, consider the following snippet from `rbart_vi` (specifically, from `rbart_vi_run`):

```r
for (i in seq_len(n.samples)) {
  # update ranef
  resid <- with(state, y.st - treeFit.train)
  post.var <- 1.0 / (n.g / state$sigma^2.0 + 1.0 / state$tau^2.0)
  post.mean <- (n.g / state$sigma^2.0) * sapply(seq_len(numRanef), function(j) mean(resid[g.sel[[j]]])) * post.var
  ranef <- rnorm(numRanef, post.mean, sqrt(post.var))
  ranef.vec <- ranef[g]
  
  # update BART params
  sampler$setOffset(ranef.vec + if (!is.null(offset.orig)) offset.orig else 0, isWarmup)
  dbarts_samples <- sampler$run(0L, 1L)
  state$treeFit.train <- as.vector(dbarts_samples$train) - ranef.vec
  if (control@binary) sampler$getLatents(state$y.st)
  state$sigma <- dbarts_samples$sigma[1L]
  
  # update sd of ranef
  evalEnv$b.sq <- sum(ranef^2.0)
  state$tau <- sliceSample(posteriorClosure, state$tau, control@n.thin, boundary = c(0.0, Inf))[control@n.thin]

...
}
```

It appears that the random intercepts are only sampled `n.samples` times in total, not `n.samples` $\times$ `n.thin` times, because each call to `sampler$run(0L, 1L)` performs its thinning iterations internally while the random intercepts are drawn once per loop. I think it would be helpful to have an option that can be switched on and off between iterations to keep (or not keep) the trees from each individual call to `sampler$run(0L, 1L)` in a `dbartsSampler`. As it stands, a sampler forgets all of its previously stored trees whenever `keepTrees = FALSE` is set.
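
A rough sketch of an interleaved version of the loop above, assuming the sampler's own `n.thin` is left at 1 so that each `sampler$run(0L, 1L)` call is exactly one BART iteration, with the thinning moved to the outer loop. Several simplifications are assumptions of this sketch, not how `rbart_vi` works: the grouped data are simulated, `tau` is held fixed instead of slice-sampled, the binary-response latent step is omitted, and `sampler$setOffset()` is called with a single argument. Only training fits are stored, so `keepTrees` can stay `FALSE`:

```r
library(dbarts)

# Simulated grouped data (assumption: 10 groups of 10 observations each)
set.seed(2)
n <- 100; p <- 5; numRanef <- 10
x <- matrix(runif(n * p), n, p)
g <- rep_len(seq_len(numRanef), n)   # group index of each observation
y <- 10 * sin(pi * x[, 1] * x[, 2]) + 5 * x[, 5] + rnorm(numRanef, 0, 2)[g] + rnorm(n)

n.iter <- 1000L; n.thin <- 5L; n.keep <- n.iter %/% n.thin
tau <- 2   # assumption: sd of the random intercepts held fixed for brevity

# n.thin = 1L so that each sampler$run(0L, 1L) is exactly one BART iteration
control <- dbartsControl(n.trees = 50L, n.samples = 1L, n.burn = 0L, n.thin = 1L,
                         n.chains = 1L, keepTrees = FALSE, keepTrainingFits = TRUE,
                         updateState = TRUE, verbose = FALSE)
sampler <- dbarts(x, y, control = control)

n.g <- tabulate(g, numRanef)
g.sel <- lapply(seq_len(numRanef), function(j) which(g == j))
treeFit <- rep(0, n)
sigma <- sd(y)

ranef.draws <- matrix(NA_real_, n.keep, numRanef)
sigma.draws <- rep(NA_real_, n.keep)
fit.draws <- matrix(NA_real_, n, n.keep)

for (k in seq_len(n.iter)) {
  # Random intercepts given the current tree fit and sigma (conjugate normal draw)
  resid <- y - treeFit
  post.var <- 1 / (n.g / sigma^2 + 1 / tau^2)
  post.mean <- (n.g / sigma^2) *
    vapply(g.sel, function(idx) mean(resid[idx]), numeric(1)) * post.var
  ranef <- rnorm(numRanef, post.mean, sqrt(post.var))

  # Exactly one BART iteration given the random intercepts
  sampler$setOffset(ranef[g])
  samples <- sampler$run(0L, 1L)
  treeFit <- as.vector(samples$train) - ranef[g]
  sigma <- samples$sigma[1L]

  # Joint thinning: every block is stored from the same iteration
  if (k %% n.thin == 0L) {
    i <- k %/% n.thin
    ranef.draws[i, ] <- ranef
    sigma.draws[i] <- sigma
    fit.draws[, i] <- treeFit
  }
}
dim(fit.draws)   # 100 200
```

This only covers in-sample fits; predictions at new points could presumably be collected the same way from the `$test` element of the run output if the sampler is created with a `test` argument, but that is untested here.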

See the following reprex for an idea:

```r
# Thinning reprex

# Simulate Data
set.seed(1)
n <- 100
p <- 5
x <- t(replicate(n, runif(p, 0, 1)))
f <- function(x) {
  10 * sin(pi * x[, 1] * x[, 2]) + 20 * (x[, 3] - 0.5)^2 + 10 * x[, 4] + 5 * x[, 5]
}

y <- rnorm(n, f(x), 1)

# Parameters
n.iter <- 1000
n.samples <- 200
n.thin <- 5

# Create dbarts sampler object
library(dbarts)
control <- dbartsControl(n.trees = 10, n.samples = n.samples, n.burn = 0,
                         n.chains = 1, keepTrees = TRUE, keepTrainingFits = TRUE,
                         updateState = TRUE, verbose = FALSE)

sampler <- dbarts(x, y, control = control)

# Run sampler for 1000 total iterations (thinned to 200 posterior samples)
for (k in seq_len(n.iter)) {
  # Only keep the trees for every 5th iteration
  control@keepTrees <- (k %% n.thin == 0)
  sampler$setControl(control)
  sampler$run(0L, 1L)
}

# The result has the correct dimension (100 x 200), but only 1 sample (the last one)
# is stored, because every time keepTrees is set to FALSE the sampler forgets all
# of the trees it had stored previously
dim(sampler$predict(x))
```

There may be a clever workaround with existing functionality (copies, etc.), but I haven't managed to make it work. My current solution is to just save all of the trees, make predictions from the model fit, and then filter out the predictions corresponding to the samples I wanted to "thin away".
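
The filtering step of this workaround comes down to selecting columns. A minimal sketch, assuming (as in the reprex output) that the prediction matrix has one column per stored iteration; `thin_draws` is a hypothetical helper, and the random matrix below only stands in for `sampler$predict(x)` after the trees from all 1000 iterations were kept:

```r
# Hypothetical helper: keep every n.thin-th draw (column) of a draws matrix
# that has one column per sampler iteration.
thin_draws <- function(draws, n.thin) {
  keep <- which(seq_len(ncol(draws)) %% n.thin == 0)
  draws[, keep, drop = FALSE]
}

# Stand-in for sampler$predict(x) with all n.iter = 1000 iterations kept
set.seed(3)
all.draws <- matrix(rnorm(100 * 1000), nrow = 100, ncol = 1000)
thinned <- thin_draws(all.draws, n.thin = 5)
dim(thinned)   # 100 200
```

The cost of this approach is that every tree is stored, so memory grows with the total number of iterations rather than with the number of retained samples.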

jacobenglert, Jan 05 '24 22:01