DynamicPPL.jl
DynamicPPL.jl copied to clipboard
Add get_merged_chains function
I have tried to add some small helper functions that allow one to merge the chains of a model with its generated quantities.
For more information regarding the background of this, please see this Discourse thread.
I am very new to the internals of DynamicPPL.jl, so this is most likely not the optimal implementation, but I was encouraged on Slack to make this PR anyway. I hope this can help in some way.
Thanks for your reply, @devmotion!
It seems like my code should maybe be taken somewhere else. Do you have any suggestions? Turing.jl? MCMCChains.jl?
Regarding your suggestions, I have a couple of comments:
- I thought that small functions were good, both from a compiler standpoint and for readability. I have changed that now.
- Is a name like `compute_…`, e.g., better? Or what are you suggesting?
- I have changed this as per your suggestion.
- I have no experience with testing in Julia, so I'll have to add tests sometime later.
My updated code is below, now as a single module:
module MergeChains

# NOTE(review): `merge` below is a new function owned by this module, distinct from
# `Base.merge`. Because it is exported, `using MergeChains` in a module that also sees
# `Base.merge` (e.g. `Main`) may make unqualified `merge` ambiguous, so callers may need
# to write `MergeChains.merge`. Renaming it would change the public interface, so the
# name is kept as is.
export merge

import Turing

"""
    get_generated_quantities(model::Turing.Model, chains::Turing.Chains)

Return the generated quantities of `model` evaluated at the samples stored in `chains`.

Only the `:parameters` section of `chains` is passed to `Turing.generated_quantities`
(presumably so that sampler statistics columns are not fed back into the model).
"""
function get_generated_quantities(model::Turing.Model, chains::Turing.Chains)
    parameter_chains = Turing.MCMCChains.get_sections(chains, :parameters)
    return Turing.generated_quantities(model, parameter_chains)
end

"""
    generated_quantities_to_chain(generated_quantities, chains, variable)
    generated_quantities_to_chain(generated_quantities, chains, variables::Tuple)

Convert one generated quantity (`variable`) or several (`variables`) into a
`Turing.Chains` object with the same iterations, chains and `info` as `chains`.

# Arguments
- `generated_quantities::AbstractMatrix`: one entry per (iteration, chain), as returned
  by `get_generated_quantities`; each entry must support `entry[variable]`
  (e.g. a `NamedTuple`).
- `chains::Turing.Chains`: the chains the quantities were generated from; supplies the
  number of iterations (`length(chains)`), the number of chains and the `info` field.
- `variable::Union{Symbol,String}`: key of the quantity to extract.

A quantity with a single element produces one column named `variable`. A quantity with
`K > 1` elements produces `K` columns named `variable[k]`, where `k` is the linear
(column-major) index of the element, so matrix-valued quantities are flattened.

For a tuple of `variables`, the per-variable chains are concatenated with `hcat`.

# Throws
- `DimensionMismatch` if the quantity does not have the same number of elements in
  every sample.
"""
function generated_quantities_to_chain(
    generated_quantities::AbstractMatrix,
    chains::Turing.Chains,
    variable::Union{Symbol,String},
)
    # Number of scalar columns (K) needed for this variable, taken from the first sample.
    K = length(first(generated_quantities)[variable])
    N_samples = length(chains)
    N_chains = length(Turing.chains(chains))
    matrix = zeros(N_samples, K, N_chains)
    for chain in 1:N_chains
        for (i, xi) in enumerate(view(generated_quantities, :, chain))
            value = xi[variable]
            length(value) == K || throw(
                DimensionMismatch(
                    "generated quantity $variable has $(length(value)) elements in " *
                    "iteration $i of chain $chain, expected $K",
                ),
            )
            # Element-wise writes handle scalars (which iterate once) as well as arrays
            # of any dimensionality; broadcasting a matrix into the 1-D slice
            # `matrix[i, :, chain]` would throw a DimensionMismatch instead.
            for (k, v) in enumerate(value)
                matrix[i, k, chain] = v
            end
        end
    end
    chain_names = if K == 1
        [Symbol(variable)]
    else
        [Symbol("$variable[$k]") for k in 1:K]
    end
    return Turing.Chains(matrix, chain_names; info = chains.info)
end

function generated_quantities_to_chain(
    generated_quantities::AbstractMatrix,
    chains::Turing.Chains,
    variables::Tuple,
)
    per_variable = map(
        variable -> generated_quantities_to_chain(generated_quantities, chains, variable),
        variables,
    )
    return hcat(per_variable...)
end

"""
    merge_generated_chains(chains::Turing.Chains, generated_chains::Turing.Chains)

Concatenate `chains` and `generated_chains` column-wise. The iteration range of
`generated_chains` is first set to `range(chains)` so both objects agree on it.
"""
function merge_generated_chains(chains::Turing.Chains, generated_chains::Turing.Chains)
    return hcat(chains, Turing.setrange(generated_chains, range(chains)))
end

"""
    merge(model::Turing.Model, chains::Turing.Chains)

Return `chains` extended with one or more columns for each generated quantity that
`model` returns.

Each entry of the generated quantities must be keyed (e.g. a `NamedTuple`); its keys
become the new parameter names. If the model returns `nothing` for every sample, or
there are no samples, `chains` is returned unchanged.
"""
function merge(model::Turing.Model, chains::Turing.Chains)
    generated_quantities = get_generated_quantities(model, chains)
    # `all` is `true` for an empty collection, so this also guards `first` below.
    if all(isnothing, generated_quantities)
        return chains
    end
    # `Tuple` is a no-op for NamedTuple keys and lets other keyed containers (e.g. a
    # `Dict`) reach the `::Tuple` method of `generated_quantities_to_chain`.
    variables = Tuple(keys(first(generated_quantities)))
    generated_chains =
        generated_quantities_to_chain(generated_quantities, chains, variables)
    return merge_generated_chains(chains, generated_chains)
end

end # module
Codecov Report
Patch coverage has no change, and project coverage changes by -1.17%.
:warning:
Comparison is base (e6dd4ef) 76.40% compared to head (67e433a) 75.24%.
Additional details and impacted files
@@ Coverage Diff @@
## master #409 +/- ##
==========================================
- Coverage 76.40% 75.24% -1.17%
==========================================
Files 21 22 +1
Lines 2522 2561 +39
==========================================
Hits 1927 1927
- Misses 595 634 +39
Impacted Files | Coverage Δ | |
---|---|---|
src/merge_generated_quantities.jl | 0.00% <0.00%> (ø) |
:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Do you have feedback about the report comment? Let us know in this issue.
Closed in favour of https://github.com/TuringLang/DynamicPPL.jl/pull/594