rstantools
Relabelling helper function for working with factors in rstan
Summary:
When using factor variables (which have meaningful level names attached), would it be useful and feasible to have a function that relabels the parameters in the Stan fit object with the relevant factor levels?
Description:
When using categorical variables in Stan you have to drop the level labels. This is fine and sensible, but for variables with many levels the output can be difficult to read. Would it be possible to write a helper function that replaces the numeric index with the relevant label for easier use? @jgabry thought it might be, but wasn't sure whether it would be too complex to create something useful in the more general case.
Reproducible Steps:
For example, in the 8 schools model the schools are coded as numbers. In real-world data, however, they would often be recorded with meaningful names (here we use state names).
schools_df <- data.frame(schools = state.name[1:8],
                         y = c(28, 8, -3, 7, -1, 1, 18, 12),
                         sigma = c(15, 10, 16, 11, 9, 11, 10, 18))
schools_df$schools <- as.factor(schools_df$schools)
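For reference, a small sketch of what dropping the labels amounts to in base R: a factor stores integer codes plus a `levels` vector, so the codes can be passed to Stan and the labels recovered later by indexing. The names `codes` and `labs` are only for illustration.

```r
# Same factor as in schools_df, built directly so this snippet runs on its own
schools <- factor(state.name[1:8])

codes <- as.integer(schools)  # integer codes, usable as a Stan index
labs  <- levels(schools)      # labels, in the order the codes refer to

# Indexing the labels by the codes recovers the original values
identical(labs[codes], as.character(schools))
```

Any relabelling helper only needs `labs`, since the index `k` in a name like `theta[k]` is a position in that vector.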
I am proposing a pair of functions: one that converts the factor to its numeric codes for use in rstan, and one that then relabels the relevant parameters in the Stan fit object. For the simple 8 schools case:
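# Convert a factor to its integer codes for use as a Stan index
delabel <- function(x) {
  as.integer(x)
}

# Replace numeric indices in parameter names (e.g. theta[1]) with factor levels
relabel <- function(x, stan_fit, variables) {
  current_names <- names(stan_fit)
  labels <- levels(x)
  for (v in variables) {
    # Match only names of the exact form v[k], so "theta" does not also hit "theta_tilde[k]"
    loc_var <- grepl(paste0("^", v, "\\[[0-9]+\\]$"), current_names)
    # Look labels up by the whole index, so "1" is never replaced inside "10"
    idx <- as.integer(sub("^.*\\[([0-9]+)\\]$", "\\1", current_names[loc_var]))
    current_names[loc_var] <- paste0(v, "[", labels[idx], "]")
  }
  names(stan_fit) <- current_names
  stan_fit
}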
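One thing to watch with any substitution-based renaming: replacing a bare index such as "1" also hits "10", "11" and so on once there are ten or more levels. Below is a rough sketch on a plain character vector, so it runs without rstan. `par_names` is a made-up set of parameter names and `month.name` stands in for `levels(x)`.

```r
# Hypothetical parameter names, as a fit with 12 groups might report them
par_names <- c(paste0("theta[", 1:12, "]"), "theta_tilde[1]", "mu")
labels    <- month.name  # stand-in for levels(x); 12 labels

# Bare substitution: the pattern "1" also matches inside "10", "11", "12"
naive <- gsub("1", labels[1], par_names)
naive[10]

# Whole-index matching: only names of the exact form theta[<digits>] are touched
is_theta <- grepl("^theta\\[[0-9]+\\]$", par_names)
idx <- as.integer(sub("^theta\\[([0-9]+)\\]$", "\\1", par_names[is_theta]))
safe <- par_names
safe[is_theta] <- paste0("theta[", labels[idx], "]")
safe[c(1, 10, 13)]
```

Extracting the index and looking the label up directly also avoids a second problem: a label that itself contains digits cannot be rewritten by a later substitution.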
The idea is that it would work something like this:
schools_dat <- list(J = 8,
                    school = delabel(schools_df$schools),
                    y = schools_df$y,
                    sigma = schools_df$sigma)
library(rstan)
fit <- stan(file = '8schools.stan', data = schools_dat)
fit2 <- relabel(schools_df$schools, fit, c('theta', 'eta'))
colnames(as.matrix(fit2))
library(bayesplot)
mcmc_areas_ridges(as.matrix(fit2, pars = "theta"))
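A lighter-weight variant, sketched under the assumption that only the plot labels matter: rename the columns of the draws matrix instead of modifying the fit object. The matrix here is simulated as a stand-in for `as.matrix(fit, pars = "theta")`, and `labs` stands in for `levels(schools_df$schools)`.

```r
# Simulated stand-in for as.matrix(fit, pars = "theta"): 100 draws, 8 columns
set.seed(1)
draws <- matrix(rnorm(100 * 8), ncol = 8,
                dimnames = list(NULL, paste0("theta[", 1:8, "]")))
labs <- levels(factor(state.name[1:8]))

# Pull the integer index out of each column name and swap in the label
idx <- as.integer(sub("^.*\\[([0-9]+)\\]$", "\\1", colnames(draws)))
colnames(draws) <- paste0("theta[", labs[idx], "]")
colnames(draws)[1:3]
```

The renamed matrix can then be passed to a bayesplot function such as `mcmc_areas_ridges()`, which takes its axis labels from the column names.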
Not sure. You can currently get better names in the output if you pass a list (of lists of) initial values with names, which then get copied into the output.