MCMCChains.jl

Extracting matrix- or array-valued parameters

Open sethaxen opened this issue 3 years ago • 2 comments

Consider the following Chains object with a matrix-valued parameter (i.e. a "group") x:

julia> using MCMCChains

julia> chain = Chains(randn(100,6,4), [Symbol("x[$i,$j]") for j in 1:2 for i in 1:3])
Chains MCMC chain (100×6×4 Array{Float64, 3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1, 2, 3, 4
Samples per chain = 100
parameters        = x[1,1], x[2,1], x[3,1], x[1,2], x[2,2], x[3,2]

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat 
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64 

      x[1,1]    0.0373    0.9782     0.0489    0.0509   423.8813    0.9969
      x[2,1]    0.0940    0.9706     0.0485    0.0152   445.5083    0.9928
      x[3,1]    0.0662    1.0092     0.0505    0.0285   394.6333    0.9936
      x[1,2]   -0.0166    1.0098     0.0505    0.0272   444.4208    0.9966
      x[2,2]    0.0330    1.0037     0.0502    0.0494   326.8740    0.9956
      x[3,2]    0.0687    0.9721     0.0486    0.0558   383.7953    1.0014

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

      x[1,1]   -1.8419   -0.5689   -0.0207    0.6834    2.1606
      x[2,1]   -1.8067   -0.5277    0.0747    0.7200    2.1326
      x[3,1]   -1.9482   -0.6006    0.0993    0.6744    2.1944
      x[1,2]   -2.0333   -0.7734   -0.0220    0.7360    1.8414
      x[2,2]   -1.8955   -0.6085    0.0160    0.6718    2.0019
      x[3,2]   -1.8967   -0.5349    0.1117    0.7130    2.0530

It would be really nice to have a way to retrieve the samples of x as a multidimensional array. Currently, we can use

julia> group(chain, :x).value
3-dimensional AxisArray{Float64,3,...} with axes:
    :iter, 1:1:100
    :var, [Symbol("x[1,1]"), Symbol("x[2,1]"), Symbol("x[3,1]"), Symbol("x[1,2]"), Symbol("x[2,2]"), Symbol("x[3,2]")]
    :chain, 1:4
And data, a 100×6×4 Array{Float64, 3}:
[:, :, 1] =
...

or

julia> get_params(chain) |> typeof
NamedTuple{(:x,), Tuple{NTuple{6, AxisArrays.AxisArray{Float64, 2, Matrix{Float64}, Tuple{AxisArrays.Axis{:iter, StepRange{Int64, Int64}}, AxisArrays.Axis{:chain, UnitRange{Int64}}}}}}}

both of which flatten the matrix dimensions. It would be nice to have a convenience function that, where possible, returns such variables as arrays. In this case, for example, it might return something like

julia> (; x = AxisArray(reshape(group(chain, :x).value, (100, 3, 2, 4)), :iter, :dim_1, :dim_2, :chain))
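The reshape above relies on Julia's column-major layout and on the parameters being stored in the order x[1,1], x[2,1], x[3,1], x[1,2], and so on. A minimal sketch of that index mapping, using a plain array as a stand-in for group(chain, :x).value (the names `param_names` and `flat` are made up for this illustration):

```julia
# Plain-array stand-in: 100 draws, 6 flattened entries of x, 4 chains.
param_names = [Symbol("x[$i,$j]") for j in 1:2 for i in 1:3]
flat = randn(100, 6, 4)

# A column-major reshape splits the parameter axis into (3, 2), with i varying fastest.
x = reshape(flat, 100, 3, 2, 4)

# Entry x[i,j] therefore sits in flat column i + 3 * (j - 1).
for j in 1:2, i in 1:3
    col = i + 3 * (j - 1)
    @assert param_names[col] == Symbol("x[$i,$j]")
    @assert x[:, i, j, :] == flat[:, col, :]
end
```

If the columns were stored in any other order, the same reshape would silently assign draws to the wrong entries of x.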

sethaxen • May 16 '21 14:05

Fundamentally, Chains seems to be the wrong data structure for such operations: it encodes the array structure only in the parameter names, but parameter names can be changed freely and parameters can be reordered arbitrarily.

However, what about something like

asarray(chain, :x, (3, 2))

that would output an Array{Union{Missing,Float64},4} of size (niterations, 3, 2, nchains) that is filled with the values of x[1,1], ..., x[3,2], where they exist?
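A rough sketch of what such a function might look like, written against a plain 3-dimensional array and a vector of names rather than a Chains object. Everything here is an assumption for illustration: the positional signature, the regex used to parse names, and the choice to skip names that do not match or fall outside the requested size. It is not an existing MCMCChains function.

```julia
# Hypothetical helper sketching the `asarray(chain, :x, (3, 2))` idea on plain arrays.
# `draws` has size (niterations, nparameters, nchains); `param_names` labels its second axis.
function asarray(draws::AbstractArray{<:Real,3}, param_names::AbstractVector{Symbol},
                 name::Symbol, dims::Dims{N}) where {N}
    niter, nparams, nchains = size(draws)
    length(param_names) == nparams ||
        throw(DimensionMismatch("expected one name per parameter column"))
    # Start from all-missing so that absent entries stay missing.
    out = Array{Union{Missing,Float64}}(missing, niter, dims..., nchains)
    for (col, pname) in enumerate(param_names)
        # Split "x[2,1]" into the base name "x" and the index text "2,1".
        m = match(r"^(.+)\[(.*)\]$", string(pname))
        m === nothing && continue
        m.captures[1] == string(name) || continue
        parts = split(m.captures[2], ',')
        length(parts) == N || continue
        idx = [tryparse(Int, strip(p)) for p in parts]
        any(isnothing, idx) && continue
        all(k -> 1 <= idx[k] <= dims[k], 1:N) || continue
        # Place the column by its parsed index, not by its position in `draws`.
        out[:, idx..., :] = draws[:, col, :]
    end
    return out
end

# Usage: columns are shuffled, yet each entry is recovered from its name.
param_names = [Symbol("x[$i,$j]") for j in 1:2 for i in 1:3]
draws = randn(100, 6, 4)
perm = [4, 1, 6, 2, 5, 3]
x = asarray(draws[:, perm, :], param_names[perm], :x, (3, 2))
@assert size(x) == (100, 3, 2, 4)
@assert x[:, 3, 1, :] == draws[:, 3, :]
```

Because entries are placed by parsing each name, the result does not depend on column order, which sidesteps the reordering concern, though renamed parameters would still be missed. For a real Chains object, the inputs could presumably be taken from names(chain) and the data array underlying chain.value; that wiring is not tested here.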

devmotion • May 16 '21 15:05

However, what about something like

asarray(chain, :x, (3, 2))

that would output an Array{Union{Missing,Float64},4} of size (niterations, 3, 2, nchains) that is filled with the values of x[1,1], ..., x[3,2], where they exist?

Oh holy shit that's what I wanted to make before. I've wanted to have this function for years now and this sounds basically exactly like what I'd want.

cpfiffer • May 17 '21 02:05