Extract groups from intercept names with spread/gather_draws
I would like to use spread_draws()/gather_draws() to extract draws from a model whose intercept varies by a categorical predictor.
When the model is specified as y ~ 0 + group, brms names the coefficients "b_groupA" and "b_groupB". However, the group cannot be extracted with m %>% gather_draws(b_group[group]). What is possible is extracting the intercepts with a regex, but that does not give a group index variable:
m %>%
gather_draws(`b_.*`, regex = TRUE)
# Groups: .variable [2]
.chain .iteration .draw .variable .value
<int> <int> <int> <chr> <dbl>
1 1 1 1 b_groupA 104.
2 1 2 2 b_groupA 104.
3 1 3 3 b_groupA 104.
4 1 4 4 b_groupA 104.
5 1 5 5 b_groupA 106.
6 1 6 6 b_groupA 105.
So far my workaround has been to use separate_wider_regex(), but that gets complicated quite quickly.
Do you think your functions could be adapted to cover such a scenario?
Code:
library(tidyverse)
library(brms)
library(tidybayes)
N <- list(a = 25, b = 30) # sample size
MEAN <- list(a = 105, b = 103) # population mean
SD <- list(a = 2, b = 5) # population sd
ya <- tibble(y = rnorm(N$a, mean = MEAN$a, sd = SD$a))
yb <- tibble(y = rnorm(N$b, mean = MEAN$b, sd = SD$b))
df <- bind_rows(list(A = ya, B = yb), .id = "group")
m <- brm(y ~ 0 + group, data = df)
get_variables(m)
# [1] "b_groupA" "b_groupB" "sigma" "lprior" "lp__" "accept_stat__" "treedepth__" "stepsize__" "divergent__" "n_leapfrog__" "energy__"
m %>%
gather_draws(`b_.*`, regex = TRUE)
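For reference, here is a minimal sketch of the separate_wider_regex() workaround mentioned above (it needs tidyr >= 1.3.0). The input tibble is a simulated stand-in for gather_draws() output, not real brms draws, and the hard-coded "b_group" prefix is an assumption. Having to spell out each predictor's prefix like this is one reason the approach becomes unwieldy with several predictors.

```r
library(tibble)
library(tidyr)

set.seed(1)
# Simulated stand-in for gather_draws(m, `b_.*`, regex = TRUE) output:
# 3 draws for each of the two group-specific intercepts.
draws <- tibble(
  .chain = 1L,
  .iteration = rep(1:3, times = 2),
  .draw = rep(1:3, times = 2),
  .variable = rep(c("b_groupA", "b_groupB"), each = 3),
  .value = rnorm(6, mean = 104)
)

# The unnamed "b_group" piece must match but is dropped; the named piece
# becomes a new `group` column. cols_remove = FALSE keeps .variable.
by_group <- separate_wider_regex(
  draws,
  .variable,
  patterns = c("b_group", group = ".+"),
  cols_remove = FALSE
)
by_group
```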
Hmm, at some point I should probably support more arbitrary indexing schemes in those functions. Just added an issue for that here: #322
Unfortunately I don't have cycles for it at the moment (though I'm happy to take PRs if someone wants to tackle it), but I can suggest using rename_with() with gsub() to standardize the names before calling spread_draws() / gather_draws().
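A minimal sketch of that renaming step generalized to several categorical predictors. bracket_levels() is a hypothetical helper, not part of tidybayes or brms; it assumes names of the form b_<predictor><level> and predictor names without regex metacharacters. Longer predictor names are processed first so that, for example, "groupsize" is not swallowed by "group".

```r
# Hypothetical helper (not a tidybayes function): rewrite brms-style names
# such as "b_groupA" into the "b_group[A]" form that spread_draws() /
# gather_draws() can index. `predictors` lists the factor names to handle.
bracket_levels <- function(x, predictors) {
  # Longest names first, so "groupsize" is handled before "group"
  predictors <- predictors[order(nchar(predictors), decreasing = TRUE)]
  for (p in predictors) {
    # Only rewrite names that are exactly b_<p><level> and not yet bracketed,
    # so a shorter predictor cannot re-match an already rewritten name.
    pattern <- paste0("^b_", p, "([^\\[\\]]+)$")
    x <- sub(pattern, paste0("b_", p, "[\\1]"), x, perl = TRUE)
  }
  x
}

bracket_levels(
  c("b_groupA", "b_groupB", "b_siteX", "b_siteY", "sigma"),
  predictors = c("group", "site")
)
# → "b_group[A]" "b_group[B]" "b_site[X]" "b_site[Y]" "sigma"
```

It would plug into rename_with() as \(x) bracket_levels(x, c("group", "site")). Note that brms naming is inherently ambiguous when one predictor name is a prefix of another (a "group" level called "size10" and a "groupsize" level "10" both yield b_groupsize10), so the longest-first rule is a heuristic.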
For example, consider this data set:
set.seed(1234)
df = data.frame(
a = rnorm(1000),
b_groupA = rnorm(1000, 1),
b_groupB = rnorm(1000, 2)
) |>
posterior::as_draws_df()
df
#> # A draws_df: 1000 iterations, 1 chains, and 3 variables
#> a b_groupA b_groupB
#> 1 -1.21 -0.21 1.03
#> 2 0.28 1.30 1.90
#> 3 1.08 -0.54 1.89
#> 4 -2.35 1.64 3.19
#> 5 0.43 1.70 0.34
#> 6 0.51 -0.91 0.95
#> 7 -0.57 1.94 0.26
#> 8 -0.55 0.78 2.51
#> 9 -0.56 0.33 1.55
#> 10 -0.89 1.45 0.16
#> # ... with 990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
You can rename the columns into the format tidybayes expects using something like this:
df |>
dplyr::rename_with(\(x) gsub("b_group(.*)", "b_group[\\1]", x))
#> # A draws_df: 1000 iterations, 1 chains, and 3 variables
#> a b_group[A] b_group[B]
#> 1 -1.21 -0.21 1.03
#> 2 0.28 1.30 1.90
#> 3 1.08 -0.54 1.89
#> 4 -2.35 1.64 3.19
#> 5 0.43 1.70 0.34
#> 6 0.51 -0.91 0.95
#> 7 -0.57 1.94 0.26
#> 8 -0.55 0.78 2.51
#> 9 -0.56 0.33 1.55
#> 10 -0.89 1.45 0.16
#> # ... with 990 more draws
#> # ... hidden reserved variables {'.chain', '.iteration', '.draw'}
This can be chained directly into spread_draws(), e.g.:
df |>
dplyr::rename_with(\(x) gsub("b_group(.*)", "b_group[\\1]", x)) |>
tidybayes::spread_draws(b_group[i]) |>
ggdist::median_qi()
#> # A tibble: 2 × 7
#> i b_group .lower .upper .width .point .interval
#> <chr> <dbl> <dbl> <dbl> <dbl> <chr> <chr>
#> 1 A 1.01 -0.973 2.84 0.95 median qi
#> 2 B 2.06 0.101 4.02 0.95 median qi
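Since the original question used gather_draws(), here is a sketch of the same renaming feeding gather_draws(b_group[group]) to get long format with a `group` index column. It reuses the simulated draws from the example above rather than a fitted brms model; the ^...$ anchors are an addition so that only names of the exact form b_group<level> are rewritten. The \(x) lambda syntax needs R >= 4.1.

```r
library(dplyr)
library(posterior)
library(tidybayes)

set.seed(1234)
# Simulated draws with brms-style names, mirroring the example above
draws <- as_draws_df(data.frame(
  a = rnorm(1000),
  b_groupA = rnorm(1000, 1),
  b_groupB = rnorm(1000, 2)
))

long <- draws |>
  # b_groupA -> b_group[A]; names that do not match are left unchanged
  rename_with(\(x) sub("^b_group(.+)$", "b_group[\\1]", x)) |>
  # Long format: one row per draw per level, level stored in `group`
  gather_draws(b_group[group])

# Posterior median and 95% quantile interval per level
summary_by_group <- ggdist::median_qi(long, .value)
summary_by_group
```

For the brms model m in the question, posterior::as_draws_df(m) should yield a draws_df containing the b_groupA / b_groupB columns, which could then go through the same rename_with() and gather_draws() steps.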