ArviZ.jl
Supporting more MCMCChains variable names
Currently from_mcmcchains assumes that all variable names are either single-bracket-delimited (e.g. x[1,2]) or dot-delimited (e.g. x.1.2). However, quite complicated names are possible:
```julia
julia> using Turing, LinearAlgebra

julia> @model function foo()
           a ~ Normal()
           bar = (b=Matrix{typeof(a)}(undef, 2, 3), c=Vector{typeof(a)}(undef, 3))
           bar.b[1, :] .~ Normal(a)
           bar.b[2, 1:2] ~ MvNormal(I(2))
           bar.b[2, 3:end] .~ Normal(a)
           bar.c[end:-1:1] .~ Normal(a)
       end
foo (generic function with 2 methods)

julia> chns = sample(foo(), NUTS(), 1_000)
┌ Info: Found initial step size
└   ϵ = 1.6
Sampling 100%|█████████████████████████████████████████████████████████████████████████| Time: 0:00:01
Chains MCMC chain (1000×22×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 1.21 seconds
Compute duration  = 1.21 seconds
parameters        = a, bar.b[1,:][1], bar.b[1,:][2], bar.b[1,:][3], bar.b[2,1:2][1], bar.b[2,1:2][2], bar.b[2,3:3][1], bar.c[3:-1:1][1], bar.c[3:-1:1][2], bar.c[3:-1:1][3]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
        parameters      mean       std  naive_se      mcse       ess      rhat ess_per_sec
            Symbol   Float64   Float64   Float64   Float64   Float64   Float64     Float64

                 a    0.0278    0.9966    0.0315    0.0724  139.5528    1.0152    115.1426
     bar.b[1,:][1]    0.0178    1.4143    0.0447    0.0841  205.0380    1.0096    169.1733
     bar.b[1,:][2]    0.0268    1.3678    0.0433    0.0765  292.3916    1.0072    241.2472
     bar.b[1,:][3]    0.0395    1.3866    0.0438    0.0852  234.2846    1.0061    193.3041
   bar.b[2,1:2][1]    0.0395    1.0091    0.0319    0.0260 1017.5267    0.9990    839.5435
   bar.b[2,1:2][2]   -0.0270    0.9839    0.0311    0.0260 1495.1411    0.9990   1233.6147
   bar.b[2,3:3][1]    0.0506    1.3798    0.0436    0.0724  347.4877    1.0032    286.7061
  bar.c[3:-1:1][1]    0.0156    1.4357    0.0454    0.0835  275.8889    1.0079    227.6311
  bar.c[3:-1:1][2]   -0.0010    1.4186    0.0449    0.0720  362.6178    1.0072    299.1896
  bar.c[3:-1:1][3]    0.0146    1.4061    0.0445    0.0783  333.0312    1.0069    274.7783

Quantiles
        parameters      2.5%     25.0%     50.0%     75.0%     97.5%
            Symbol   Float64   Float64   Float64   Float64   Float64

                 a   -1.8465   -0.6502   -0.0164    0.7037    2.0664
     bar.b[1,:][1]   -2.7699   -0.9711    0.0028    0.9771    2.7813
     bar.b[1,:][2]   -2.5305   -0.9451    0.0408    0.9232    2.7366
     bar.b[1,:][3]   -2.6779   -0.9212    0.0311    1.0146    2.7213
   bar.b[2,1:2][1]   -1.8999   -0.6486    0.0397    0.7054    2.1453
   bar.b[2,1:2][2]   -1.9572   -0.7017   -0.0031    0.6454    1.8877
   bar.b[2,3:3][1]   -2.6416   -0.8997   -0.0014    0.9453    3.0213
  bar.c[3:-1:1][1]   -2.8470   -0.9299   -0.0128    0.9509    2.8362
  bar.c[3:-1:1][2]   -2.8268   -0.9619    0.0671    0.9812    2.7441
  bar.c[3:-1:1][3]   -2.7313   -0.9267    0.0001    0.9474    2.8913
```
If we call from_mcmcchains on this, we get an uninformative error.
Ideally we would like to get an InferenceData with bar.b and bar.c as variables. However, since the modeler can arbitrarily index and reindex and call getproperty to make arbitrarily complicated types, always doing the right thing is probably not possible. Also, MCMCChains's own machinery for combining flattened parameters into parameter arrays doesn't do a great job here.
For the short term, then, I think it makes the most sense to raise an informative error if splitting by brackets produces anything more complicated than a tuple of integer indices. If users find this constraining and open issues, we can discuss supporting slightly more complicated indexing syntaxes.
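A minimal sketch of that check. It is a standalone Julia function, not code written against ArviZ's internals. The name split_name_strict is hypothetical, and for brevity it handles only the bracket form, not the dot form. A name without brackets is treated as a scalar. Anything other than a single trailing group of integer indices raises an ArgumentError that names the offending parameter.

```julia
# Hypothetical helper, not ArviZ's API: split a flattened parameter name such
# as "x[1,2]" into ("x", (1, 2)), failing loudly on anything fancier.
function split_name_strict(name::AbstractString)
    # No brackets at all: treat it as a scalar parameter.
    occursin('[', name) || return (String(name), ())
    # Require exactly one bracket group, and it must end the name.
    m = match(r"^([^\[\]]+)\[([^\[\]]*)\]$", name)
    m === nothing && throw(ArgumentError(
        "cannot parse parameter name \"$name\": expected a single trailing " *
        "bracket group of integer indices, e.g. x[1,2]"))
    # Every comma-separated entry inside the brackets must parse as an Int.
    idx = tryparse.(Int, strip.(split(m.captures[2], ',')))
    any(isnothing, idx) && throw(ArgumentError(
        "cannot parse parameter name \"$name\": indices \"$(m.captures[2])\" " *
        "are not all integers"))
    return (String(m.captures[1]), Tuple(idx))
end
```

With a check like this, a name such as bar.b[1,:][1] from the model above would fail with a message pointing at the extra bracket group instead of an uninformative error.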
Was there ever a resolution on this? I think I'm hitting an issue with a project I'm working on.
No, we haven't implemented a solution for this (mostly because no one else reported this being an issue). But now that LKJCholesky draws will appear in Chains with a syntax like F.L[i, j], which I think would hit this issue, it's a bit higher priority.
The long-term goal was to make InferenceData a chain type that could be returned by sample instead of Chains, so we could maybe avoid Chains's flattening altogether (see https://github.com/TuringLang/DynamicPPL.jl/issues/464). But that's a much bigger project I haven't had the focused time to finish.
Can you share what your Chains object looks like? I might be able to come up with a workaround or a quick patch for that case.
Hey @sethaxen,
A slice of the output from a quick example run gives names like this (yes, I know the convergence is horrible here ;-)):
```
Summary Statistics
             parameters      mean       std      mcse  ess_bulk  ess_tail      rhat ess_per_sec
                 Symbol   Float64   Float64   Float64   Float64   Float64   Float64     Float64

  latent.latent_init[1]   -0.0390    0.0928    0.0246   14.7598   65.6111    1.2054      0.0129
            latent.σ_AR    0.4779    0.2497    0.1204    4.8553   11.2507    2.5235      0.0042
      latent.ar_init[1]    0.0588    0.1163    0.0403    8.8253   15.4617    1.3860      0.0077
      latent.damp_AR[1]    0.4611    0.1653    0.0787    4.8405   11.6086    2.5636      0.0042
          latent.ϵ_t[1]   -0.0325    0.2224    0.0860    6.1744   27.2663    1.7365      0.0054
          latent.ϵ_t[2]   -0.6047    0.6202    0.2991    4.6042   11.6951    3.0378      0.0040
         init_incidence    4.6052    0.0000    0.0000 2131.9378 1288.5933    1.0015      1.8563
                obs.std    0.2493    0.3386    0.1667    4.8690   11.5049    2.4899      0.0042
             obs.ϵ_t[1]    0.1988    0.4118    0.1979    5.6882   15.6413    1.9164      0.0050
```
Okay, I think we can support this by removing support for the old dot syntax for separating indices. Only Stan sample/CmdStan ever used that, and now they have native InferenceData support. I'll put together a PR.
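One way to read this proposal is sketched below in Julia: split only on the last bracket group and treat everything before it, dots included, as the variable name. The function split_trailing_index is hypothetical and is not the code from the actual PR. Keeping names whole as scalars when their trailing group isn't all integers is likewise an assumption made for this sketch.

```julia
# Hypothetical sketch (not the code from the actual PR): split a flattened name
# on its last bracket group only; dots are never treated as index separators.
function split_trailing_index(name::AbstractString)
    # The greedy (.+) makes the captured bracket group the final one in the name.
    m = match(r"^(.+)\[([^\[\]]+)\]$", name)
    m === nothing && return (String(name), ())
    idx = tryparse.(Int, strip.(split(m.captures[2], ',')))
    # Non-integer indices (e.g. "bar.b[1,:]"): keep the whole name as a scalar.
    any(isnothing, idx) && return (String(name), ())
    return (String(m.captures[1]), Tuple(idx))
end
```

Under this rule latent.σ_AR and obs.std stay scalars, and latent.ϵ_t[1] splits into latent.ϵ_t with index (1,). An LKJCholesky-style name like F.L[1, 2] splits into F.L with indices (1, 2).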
Wow! Quick work. Thanks.
@SamuelBrand1, with the just-released ArviZ v0.11.0, you shouldn't have any issues with your example model, but let me know if you run into any. For the model at the top of this issue, instead of an error you would now get the variable names a, bar.b[1, :], bar.b[2, 1:2], etc.
Note that many of the fancy indexing cases in https://github.com/arviz-devs/ArviZ.jl/issues/211#issue-1331742087 will now no longer error but will instead just not be processed into multi-dimensional arrays.
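The sketch below illustrates a grouping step that could follow the name splitting and shows what "processed into multi-dimensional arrays" means. It assumes each flattened parameter has already been split into a (base name, integer indices, draws) triple. group_draws is a hypothetical helper, not an ArviZ function. Entries that share a base name are stacked into an array of size (draws, dims...), with any unfilled positions left as NaN. Scalars stay as plain vectors of draws.

```julia
# Hypothetical helper, not part of ArviZ: stack flattened draws into arrays.
# Each entry is (base_name, index_tuple, draws); index_tuple is () for scalars.
function group_draws(entries)
    groups = Dict{String,Vector{Any}}()
    for (base, idx, draws) in entries
        push!(get!(groups, base, Any[]), (idx, draws))
    end
    out = Dict{String,Array{Float64}}()
    for (base, items) in groups
        idx1, draws1 = items[1]
        if isempty(idx1)
            # Scalar parameter: keep its (first) vector of draws as-is.
            out[base] = collect(Float64, draws1)
            continue
        end
        # Array shape is the largest index seen along each dimension.
        dims = ntuple(k -> maximum(it[1][k] for it in items), length(idx1))
        arr = fill(NaN, length(draws1), dims...)  # first axis indexes draws
        for (idx, draws) in items
            arr[:, idx...] = draws
        end
        out[base] = arr
    end
    return out
end
```

Under a scheme like this, a fancy name such as bar.b[1,:][1] would end up under the base name bar.b[1,:] rather than being merged into a single bar.b matrix.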
Thanks so much, I'll give this a whirl now
It works for our use case! Thanks so much for the speedy work here.
No problem!