MCMCChains.jl
Extracting matrix- or array-valued parameters
Consider the following Chains object with a matrix-valued parameter (i.e. a "group") x:
julia> using MCMCChains
julia> chain = Chains(randn(100,6,4), [Symbol("x[$i,$j]") for j in 1:2 for i in 1:3])
Chains MCMC chain (100×6×4 Array{Float64, 3}):
Iterations = 1:100
Thinning interval = 1
Chains = 1, 2, 3, 4
Samples per chain = 100
parameters = x[1,1], x[2,1], x[3,1], x[1,2], x[2,2], x[3,2]
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
x[1,1] 0.0373 0.9782 0.0489 0.0509 423.8813 0.9969
x[2,1] 0.0940 0.9706 0.0485 0.0152 445.5083 0.9928
x[3,1] 0.0662 1.0092 0.0505 0.0285 394.6333 0.9936
x[1,2] -0.0166 1.0098 0.0505 0.0272 444.4208 0.9966
x[2,2] 0.0330 1.0037 0.0502 0.0494 326.8740 0.9956
x[3,2] 0.0687 0.9721 0.0486 0.0558 383.7953 1.0014
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
x[1,1] -1.8419 -0.5689 -0.0207 0.6834 2.1606
x[2,1] -1.8067 -0.5277 0.0747 0.7200 2.1326
x[3,1] -1.9482 -0.6006 0.0993 0.6744 2.1944
x[1,2] -2.0333 -0.7734 -0.0220 0.7360 1.8414
x[2,2] -1.8955 -0.6085 0.0160 0.6718 2.0019
x[3,2] -1.8967 -0.5349 0.1117 0.7130 2.0530
It would be really nice to have a way to retrieve the samples of x as a multidimensional array. Currently, we can use
julia> group(chain, :x).value
3-dimensional AxisArray{Float64,3,...} with axes:
:iter, 1:1:100
:var, [Symbol("x[1,1]"), Symbol("x[2,1]"), Symbol("x[3,1]"), Symbol("x[1,2]"), Symbol("x[2,2]"), Symbol("x[3,2]")]
:chain, 1:4
And data, a 100×6×4 Array{Float64, 3}:
[:, :, 1] =
...
or
julia> get_params(chain) |> typeof
NamedTuple{(:x,), Tuple{NTuple{6, AxisArrays.AxisArray{Float64, 2, Matrix{Float64}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}}}}
both of which flatten the matrix dimension. It would be nice to have a convenience function that, where possible, returns such variables as Arrays. E.g., in this case it might return something like
julia> (; x = AxisArray(reshape(group(chain, :x).value, (100, 3, 2, 4)), :iter, :dim_1, :dim_2, :chain))
Fundamentally, it seems Chains is the wrong data structure for such operations: it encodes the array structure only in the parameter names, but parameter names can be changed freely and parameters can be reordered arbitrarily.
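To make that point concrete, here is a rough sketch of what recovering the shape from the names alone would involve. Both parse_index and infer_dims are hypothetical helpers written for this comment, not part of MCMCChains, and they assume names of the form x[i,j] as in the example above.

```julia
# Hypothetical helper: recover the index tuple from a name such as Symbol("x[3,2]").
# Returns `nothing` if the name does not belong to the group `sym`.
function parse_index(name::Symbol, sym::Symbol)
    m = match(r"^(.+)\[([\d,\s]+)\]$", String(name))
    (m === nothing || m.captures[1] != String(sym)) && return nothing
    return Tuple(parse.(Int, strip.(split(m.captures[2], ","))))
end

# Hypothetical helper: smallest array size that holds every index found for `sym`.
# Returns an empty tuple if no name belongs to the group.
function infer_dims(parnames, sym::Symbol)
    idxs = filter(!isnothing, [parse_index(n, sym) for n in parnames])
    isempty(idxs) && return ()
    return Tuple(maximum(i[d] for i in idxs) for d in 1:length(first(idxs)))
end

parnames = [Symbol("x[$i,$j]") for j in 1:2 for i in 1:3]
infer_dims(parnames, :x)  # shape inferred purely from the names
```

Once the parameters are renamed, the same call no longer finds the group at all, which is exactly the fragility described above.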
However, what about something like
asarray(chain, :x, (3, 2))
that would output an Array{Union{Missing,Float64},4} of size (niterations, 3, 2, nchains) that is filled with the values of x[1,1], ..., x[3,2], where they exist?
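A minimal sketch of how such a function might look, not a definitive implementation. asarray is the hypothetical name from the proposal and does not exist in MCMCChains. To keep the sketch dependency-free it works on the raw iterations × parameters × chains array plus the vector of parameter names; with a Chains object one would presumably pass chain.value.data and names(chain). It assumes names are formatted exactly as x[i,j], without spaces.

```julia
# Hypothetical helper (not part of MCMCChains): collect the parameters named
# "sym[i,j,...]" into one array of size (niterations, dims..., nchains).
function asarray(data::AbstractArray{T,3}, parnames::AbstractVector{Symbol},
                 sym::Symbol, dims::Dims{N}) where {T,N}
    niter, nvar, nchains = size(data)
    length(parnames) == nvar ||
        throw(DimensionMismatch("need exactly one name per parameter column"))
    # Start with everything missing; entries are filled in only when found.
    out = fill!(Array{Union{Missing,T}}(undef, niter, dims..., nchains), missing)
    for I in CartesianIndices(dims)
        # Look each entry up by name, so the column order of `data` is irrelevant.
        name = Symbol(sym, "[", join(Tuple(I), ","), "]")
        col = findfirst(==(name), parnames)
        col === nothing && continue  # entry absent: leave it as `missing`
        out[:, I, :] .= view(data, :, col, :)
    end
    return out
end

# Usage with the same layout as the example chain above.
parnames = [Symbol("x[$i,$j]") for j in 1:2 for i in 1:3]
data = randn(100, 6, 4)
x = asarray(data, parnames, :x, (3, 2))
size(x)  # (100, 3, 2, 4)
```

Because the lookup goes by name rather than by position, reordered parameters still end up in the right slot, and anything that was dropped or renamed simply stays missing.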
Oh holy shit that's what I wanted to make before. I've wanted to have this function for years now and this sounds basically exactly like what I'd want.