rstantools

Relabelling helper function for working with factors in rstan

Open lauken13 opened this issue 6 years ago • 1 comments

Summary:

When using factor variables (which have meaningful level names attached), would it be useful and feasible to have a function that relabels variables in the Stan fit object with the relevant factor levels?

Description:

When using categorical variables in Stan, you have to drop the level labels. This is fine and sensible, but for variables with many levels the output can be difficult to read. Would it be possible to write a helper function that replaces the numeric index with the relevant label for easier use? @jgabry thought it might be, but wasn't sure whether it would be too complex to create something useful in the general case.

Reproducible Steps:

For example, in the 8 schools model the schools are coded as numbers. In real-world data, however, they would often be recorded with meaningful names (here we use states).

schools_df <- data.frame(schools = state.name[1:8],
                         y = c(28,  8, -3,  7, -1,  1, 18, 12),
                         sigma = c(15, 10, 16, 11,  9, 11, 10, 18))
schools_df$schools <- as.factor(schools_df$schools)
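To make concrete what is lost when the labels are dropped, here is a minimal base R sketch. The three school names are made up for illustration and are not part of the data above. It shows that the integer codes follow the sorted level order rather than the row order of the data, which is what any relabelling step has to respect.

```r
# Hypothetical three-school example; the names are illustrative only.
schools <- factor(c("Texas", "Ohio", "Utah"))

codes  <- as.integer(schools)  # integer index that would be passed to Stan
labels <- levels(schools)      # labels that are dropped in the process

# Codes follow the sorted level order, not the row order of the data.
print(data.frame(school = schools, code = codes))

# Indexing the levels by the codes recovers the original labels.
stopifnot(identical(labels[codes], as.character(schools)))
```

In the 8 schools data above the states are already in alphabetical order, so row order and level order happen to coincide.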

I am proposing a pair of functions: one that coerces the factors to numeric for use in rstan, and one that then relabels the relevant variables in the Stan fit object. For example, for the simple 8 schools case:

delabel <- function(x){
  return(as.numeric(x))
}

relabel <- function(x, stan_fit, variables){
   current_names <- names(stan_fit)
   labels <- levels(x)
   for (v in variables){
     # Match whole names such as "theta[1]" so that index 1 is not also
     # replaced inside "theta[10]"; only vector-valued parameters are handled.
     old <- paste0(v, "[", seq_along(labels), "]")
     new <- paste0(v, "[", labels, "]")
     hit <- match(current_names, old)  # NA where a name is not of the form v[k]
     current_names[!is.na(hit)] <- new[hit[!is.na(hit)]]
   }
   names(stan_fit) <- current_names
   return(stan_fit)
}
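One thing to watch when relabelling by substring replacement with gsub() is that the pattern "1" also matches inside "10", "11" and so on once a factor has more than nine levels. The sketch below works on a plain character vector, so it runs without a fitted model, and contrasts substring replacement with exact matching of whole parameter names. The helper name relabel_names and the parameter names in par_names are hypothetical, chosen only for illustration.

```r
# Hypothetical helper: relabel names of the form "var[k]" using exact matches.
relabel_names <- function(par_names, labels, variables) {
  for (v in variables) {
    old <- paste0(v, "[", seq_along(labels), "]")
    new <- paste0(v, "[", labels, "]")
    hit <- match(par_names, old)  # NA where a name is not of the form v[k]
    par_names[!is.na(hit)] <- new[hit[!is.na(hit)]]
  }
  par_names
}

labels <- state.name[1:12]  # twelve levels, so two-digit indices occur
par_names <- c("mu", paste0("theta[", 1:12, "]"), "theta_tilde[1]")

# Substring replacement also rewrites the "1" inside "10".
bad <- gsub("1", labels[1], "theta[10]")
print(bad)

# Exact matching leaves "mu" and "theta_tilde[1]" untouched.
new_names <- relabel_names(par_names, labels, "theta")
print(new_names)
```

Matching on the full name also keeps a prefix such as "theta" from catching an unrelated parameter like "theta_tilde".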

The idea is that it would work something like this:

schools_dat <- list(J = 8, 
                    school = delabel(schools_df$schools),
                    y = schools_df$y,
                    sigma = schools_df$sigma)

library(rstan)

fit <- stan(file = '8schools.stan', data = schools_dat)

fit2 <- relabel(schools_df$schools,fit,c('theta','eta'))

colnames(as.matrix(fit2))
library(bayesplot)
mcmc_areas_ridges(as.matrix(fit2, pars = "theta"))
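A lighter-weight variant, sketched here under the assumption that only the plots need readable names, would be to leave the fit untouched and rename the columns of the draws matrix instead. The matrix below is simulated with rnorm() purely as a stand-in for as.matrix(fit); the values are not real posterior draws.

```r
set.seed(1)
labels <- levels(factor(state.name[1:8]))

# Simulated stand-in for as.matrix(fit): 20 draws of mu and theta[1..8].
draws <- matrix(rnorm(20 * 9), nrow = 20,
                dimnames = list(NULL, c("mu", paste0("theta[", 1:8, "]"))))

# Find the theta columns and pull out their integer index.
is_theta <- grepl("^theta\\[[0-9]+\\]$", colnames(draws))
idx <- as.integer(sub("^theta\\[([0-9]+)\\]$", "\\1", colnames(draws)[is_theta]))

# Replace each index with the corresponding factor level.
colnames(draws)[is_theta] <- paste0("theta[", labels[idx], "]")
print(colnames(draws))

# The renamed matrix could then be passed on for plotting, e.g.
# bayesplot::mcmc_areas_ridges(draws[, is_theta])
```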

lauken13, Jan 17 '19 22:01

Not sure. You can currently get better names in the output if you pass a list (of lists of) initial values with names, which then get copied into the output.

bgoodri, Mar 21 '19 00:03