posterior
Rollups of summaries of variables with indices, version 2
Summary
Now that some improved index-handling code is in, I thought I'd make a second attempt at rollups of summaries of variables with indices. If it goes forward, this PR will close #43.
This is based on @jsocolar's #152, except that I have iterated on the interface somewhat to make it more generic. It allows arbitrary rollup functions to be given for each original summary column, and supplies both overall default rollup functions and summary-specific ones (e.g. by default the ess summaries are rolled up with min and the rhat summaries with max).
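To make the mechanics concrete, here is a minimal base-R sketch of what a rollup does, assuming a summary table with one row per element (as `summarise_draws()` returns). It is not the code in this PR: `base_name()` and `rollup_sketch()` are hypothetical helpers, and the input values are toy numbers rather than real output. Variables without indices pass through unrolled; indexed elements are grouped by base name, and each summary column is reduced with its own rollup functions, falling back to `min`/`max`.

```r
# Hypothetical helper: drop a trailing "[...]" index to get the base name,
# e.g. "theta[1]" -> "theta"
base_name <- function(v) sub("\\[.*\\]$", "", v)

# Hypothetical sketch (not the PR's implementation). `funs` maps summary
# column names to a named list of rollup functions; columns not listed in
# `funs` use `default`.
rollup_sketch <- function(summ, funs = list(),
                          default = list(min = min, max = max)) {
  base <- base_name(summ$variable)
  indexed <- base != summ$variable
  stat_cols <- setdiff(names(summ), "variable")
  rows <- lapply(unique(base[indexed]), function(v) {
    elems <- summ[indexed & base == v, stat_cols, drop = FALSE]
    out <- list(variable = v)
    for (col in stat_cols) {
      fs <- if (is.null(funs[[col]])) default else funs[[col]]
      for (fname in names(fs)) {
        out[[paste(col, fname, sep = "_")]] <- fs[[fname]](elems[[col]])
      }
    }
    as.data.frame(out, stringsAsFactors = FALSE)
  })
  list(unrolled = summ[!indexed, , drop = FALSE],
       rolled = do.call(rbind, rows))
}

# Toy per-element summaries (made-up values, not real posterior output)
summ <- data.frame(
  variable = c("mu", "theta[1]", "theta[2]", "theta[3]"),
  mean = c(4.2, 3.0, 6.8, 5.1),
  rhat = c(1.02, 1.01, 1.03, 1.00),
  stringsAsFactors = FALSE
)
# roll rhat up with max only; mean falls back to min and max
res <- rollup_sketch(summ, funs = list(rhat = list(max = max)))
res$rolled   # one row: theta, with mean_min = 3, mean_max = 6.8, rhat_max = 1.03
res$unrolled # the mu row, unchanged
```

Keeping `unrolled` and `rolled` as separate tables, as the `<rollup_summary>` print-outs in the demo do, avoids padding scalar variables with rolled-up columns they don't need.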
Demo:
library(posterior)
x <- example_draws()
# you can roll up summaries of array-like variables by rolling up draws
# objects directly; this will apply the default options of summarise_draws()
rollup_summary(x)
#> <rollup_summary>:
#>
#> $unrolled (variables that have not been rolled up):
#> # A tibble: 2 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 mu 4.18 4.16 3.40 3.57 -0.854 9.39 1.02 558. 322.
#> 2 tau 4.16 3.07 3.58 2.89 0.309 11.0 1.01 246. 202.
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 1 × 17
#> variable dim mean_min mean_max median_min median_max sd_min sd_max mad_min
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 theta 8 3.04 6.75 3.72 5.97 4.63 6.80 4.25
#> # ℹ 8 more variables: mad_max <dbl>, q5_min <dbl>, q5_max <dbl>, q95_min <dbl>,
#> # q95_max <dbl>, rhat_max <dbl>, ess_bulk_min <dbl>, ess_tail_min <dbl>
# or summarise draws objects first to pick the desired summary measures
# (note that ess_bulk is only rolled up using min by default)
ds <- summarise_draws(x, "mean", "sd", "ess_bulk")
rollup_summary(ds)
#> <rollup_summary>:
#>
#> $unrolled (variables that have not been rolled up):
#> # A tibble: 2 × 4
#> variable mean sd ess_bulk
#> <chr> <dbl> <dbl> <dbl>
#> 1 mu 4.18 3.40 558.
#> 2 tau 4.16 3.58 246.
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 1 × 7
#> variable dim mean_min mean_max sd_min sd_max ess_bulk_min
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 theta 8 3.04 6.75 4.63 6.80 312.
# rollups work on variables of any dimension
x <- example_draws(example = "multi_normal")
rollup_summary(x)
#> <rollup_summary>:
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 2 × 17
#> variable dim mean_min mean_max median_min median_max sd_min sd_max mad_min
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 mu 3 0.0514 0.186 0.0575 0.184 0.112 0.314 0.131
#> 2 Sigma 3,3 -2.10 8.12 -2.11 8.02 0.165 0.946 0.173
#> # ℹ 8 more variables: mad_max <dbl>, q5_min <dbl>, q5_max <dbl>, q95_min <dbl>,
#> # q95_max <dbl>, rhat_max <dbl>, ess_bulk_min <dbl>, ess_tail_min <dbl>
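The `dim` column in these tables records the extent of each rolled-up variable. One simple way to derive such a string from element names is sketched below, assuming 1-based integer indices and that every element of the variable is present. `infer_dim()` is a hypothetical helper; posterior's own index-handling code may well do this differently.

```r
# Hypothetical helper: infer a "dim" string such as "3,3" from element
# names like "Sigma[2,3]" by taking the largest index in each position.
infer_dim <- function(var_names) {
  # keep only what is inside the brackets, e.g. "Sigma[2,3]" -> "2,3"
  idx <- sub(".*\\[(.*)\\]$", "\\1", var_names)
  # one row per element, one column per dimension
  m <- do.call(rbind, lapply(strsplit(idx, ",", fixed = TRUE), as.integer))
  paste(apply(m, 2, max), collapse = ",")
}

sigma_names <- paste0("Sigma[", rep(1:3, times = 3), ",", rep(1:3, each = 3), "]")
infer_dim(sigma_names)                    # "3,3"
infer_dim(c("mu[1]", "mu[2]", "mu[3]"))   # "3"
```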
# you can roll up only some variables
rollup_summary(x, variable = "Sigma")
#> <rollup_summary>:
#>
#> $unrolled (variables that have not been rolled up):
#> # A tibble: 3 × 10
#> variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
#> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 mu[1] 0.0514 0.0575 0.112 0.131 -0.130 0.225 1.01 677. 356.
#> 2 mu[2] 0.111 0.104 0.199 0.198 -0.208 0.449 1.00 566. 426.
#> 3 mu[3] 0.186 0.184 0.314 0.315 -0.322 0.715 1.02 650. 334.
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 1 × 17
#> variable dim mean_min mean_max median_min median_max sd_min sd_max mad_min
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 Sigma 3,3 -2.10 8.12 -2.11 8.02 0.165 0.946 0.173
#> # ℹ 8 more variables: mad_max <dbl>, q5_min <dbl>, q5_max <dbl>, q95_min <dbl>,
#> # q95_max <dbl>, rhat_max <dbl>, ess_bulk_min <dbl>, ess_tail_min <dbl>
# you can specify the rollup functions to apply to all summaries by passing
# unnamed parameters ...
rollup_summary(x, "mean", "min")
#> <rollup_summary>:
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 2 × 17
#> variable dim mean_mean mean_min median_mean median_min sd_mean sd_min
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 mu 3 0.116 0.0514 0.115 0.0575 0.208 0.112
#> 2 Sigma 3,3 1.01 -2.10 0.989 -2.11 0.387 0.165
#> # ℹ 9 more variables: mad_mean <dbl>, mad_min <dbl>, q5_mean <dbl>,
#> # q5_min <dbl>, q95_mean <dbl>, q95_min <dbl>, rhat_max <dbl>,
#> # ess_bulk_min <dbl>, ess_tail_min <dbl>
# ... or use names to specify rollup functions for specific summaries
rollup_summary(x, mean = "sd", median = "min")
#> <rollup_summary>:
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 2 × 15
#> variable dim mean_sd median_min sd_min sd_max mad_min mad_max q5_min q5_max
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 mu 3 0.0673 0.0575 0.112 0.314 0.131 0.315 -0.322 -0.130
#> 2 Sigma 3,3 3.19 -2.11 0.165 0.946 0.173 0.941 -2.87 6.71
#> # ℹ 5 more variables: q95_min <dbl>, q95_max <dbl>, rhat_max <dbl>,
#> # ess_bulk_min <dbl>, ess_tail_min <dbl>
# the same convention (unnamed for defaults, named for specific summaries) is
# also used to specify the default rollups in the `.funs` parameter, whose
# default value is:
default_rollups()
#> [[1]]
#> [1] "min" "max"
#>
#> $ess_basic
#> [1] "min"
#>
#> $ess_bulk
#> [1] "min"
#>
#> $ess_mean
#> [1] "min"
#>
#> $ess_median
#> [1] "min"
#>
#> $ess_quantile
#> [1] "min"
#>
#> $ess_sd
#> [1] "min"
#>
#> $ess_tail
#> [1] "min"
#>
#> $rhat
#> [1] "max"
#>
#> $rhat_basic
#> [1] "max"
#>
#> $rhat_nested
#> [1] "max"
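The demo outputs above suggest how user-supplied and default rollups combine: in `rollup_summary(x, "mean", "min")` the unnamed user rollups replace the unnamed default, yet `rhat` is still rolled up with `max` and the ess columns with `min`, while a named argument such as `mean = "sd"` overrides everything for that summary. The sketch below encodes that precedence (user-named, then default-named, then user-unnamed, then default-unnamed). It is inferred from the printed output rather than taken from the PR's code; `resolve_rollups()` is a hypothetical helper that works with function names as strings for simplicity, and `defaults` is an abbreviated copy of `default_rollups()`.

```r
# Hypothetical helper: return the rollup function names applied to one
# summary column. Precedence is inferred from the demo output, not from
# the PR's code: user-named > default-named > user-unnamed > default-unnamed.
resolve_rollups <- function(summary, user = list(), defaults = list()) {
  # rollups given specifically for this summary, or NULL
  named <- function(funs) {
    if (!is.null(names(funs)) && summary %in% names(funs)) funs[[summary]] else NULL
  }
  # rollups given without a name (the "apply to everything" default), or NULL
  unnamed <- function(funs) {
    nms <- names(funs)
    if (is.null(nms)) nms <- rep("", length(funs))
    out <- unlist(funs[nms == ""], use.names = FALSE)
    if (length(out) == 0) NULL else out
  }
  candidates <- list(named(user), named(defaults), unnamed(user), unnamed(defaults))
  for (cand in candidates) {
    if (!is.null(cand)) return(cand)
  }
  character(0)
}

# abbreviated version of default_rollups()
defaults <- list(c("min", "max"), ess_bulk = "min", ess_tail = "min", rhat = "max")

resolve_rollups("mean", list("mean", "min"), defaults)  # "mean" "min"
resolve_rollups("rhat", list("mean", "min"), defaults)  # "max"
resolve_rollups("mean", list(mean = "sd"), defaults)    # "sd"
resolve_rollups("sd",   list(mean = "sd"), defaults)    # "min" "max"
```

Under this reading, unnamed rollups never change how the rhat or ess summaries are rolled up; to change those you would name them explicitly (e.g. `rhat = "mean"`).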
# rollups can be chained to provide different rollup functions to
# different variables
x |>
summarise_draws("mean", "sd") |>
rollup_summary(variable = "mu", sd = "min") |>
rollup_summary(variable = "Sigma", sd = "max")
#> <rollup_summary>:
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 2 × 6
#> variable dim mean_min mean_max sd_min sd_max
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl>
#> 1 mu 3 0.0514 0.186 0.112 NA
#> 2 Sigma 3,3 -2.10 8.12 NA 0.946
# you could ignore NAs on a specific rollup using an anonymous function,
# though it is perhaps a bit kludgy
x2 <- draws_rvars(x = c(rvar_rng(rnorm, 5), NA))
rollup_summary(x2, min)
#> <rollup_summary>:
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 1 × 11
#> variable dim mean_min median_min sd_min mad_min q5_min q95_min rhat_max
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 x 6 NA NA NA NA NA NA NA
#> # ℹ 2 more variables: ess_bulk_min <dbl>, ess_tail_min <dbl>
rollup_summary(x2, list(min = \(x) min(x, na.rm = TRUE)))
#> <rollup_summary>:
#>
#> $rolled (variables that have been rolled up):
#> # A tibble: 1 × 11
#> variable dim mean_min median_min sd_min mad_min q5_min q95_min rhat_max
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 x 6 -0.0320 -0.0405 0.985 0.985 -1.67 1.60 NA
#> # ℹ 2 more variables: ess_bulk_min <dbl>, ess_tail_min <dbl>
That last bit about NAs is IMO the weakest part of the API at the moment, and I'm not sure whether it needs fixing. One option is to add an overall na.rm = TRUE option that removes NAs before they are passed to the rollup functions. We could also allow NAs to be removed only for specific summaries via a named argument, e.g. na.rm = c(FALSE, ess_bulk = TRUE).
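One way the `na.rm = c(FALSE, ess_bulk = TRUE)` idea could work is sketched below: the unnamed entry is the default and named entries override it per summary, reusing the same convention as `.funs`. `resolve_na_rm()` and `with_na_rm()` are hypothetical helpers, not part of posterior, and the values are toy numbers. Dropping NAs before calling the rollup function (rather than passing `na.rm = TRUE` through) means it also works for functions that have no `na.rm` argument, such as user-supplied anonymous functions.

```r
# Hypothetical: look up the na.rm setting for one summary from a vector
# like c(FALSE, ess_bulk = TRUE); the unnamed entry is the default.
resolve_na_rm <- function(na_rm, summary) {
  nms <- names(na_rm)
  if (is.null(nms)) nms <- rep("", length(na_rm))
  if (summary %in% nms) return(unname(na_rm[nms == summary][1]))
  default <- na_rm[nms == ""]
  if (length(default) == 0) FALSE else unname(default[1])
}

# Hypothetical: wrap a rollup function so NAs are dropped before it is called
with_na_rm <- function(f, na_rm) {
  if (!na_rm) return(f)
  function(x) f(x[!is.na(x)])
}

na_rm <- c(FALSE, ess_bulk = TRUE)
vals <- c(312, NA, 450)  # toy per-element values
with_na_rm(min, resolve_na_rm(na_rm, "ess_bulk"))(vals)  # 312
with_na_rm(min, resolve_na_rm(na_rm, "mean"))(vals)      # NA
```

If every element is NA, `min()` on the emptied vector returns `Inf` with a warning, so an actual implementation would need to decide what to return in that case.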
Pinging folks who seemed interested in this from previous issues: @avehtari @paul-buerkner @jgabry @jsocolar @andrewgelman
Copyright and Licensing
By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:
- Code: BSD 3-clause (https://opensource.org/licenses/BSD-3-Clause)
- Documentation: CC-BY 4.0 (https://creativecommons.org/licenses/by/4.0/)