DynamicPPL.jl

Add get_merged_chains function

Open · ChristianMichelsen opened this pull request 2 years ago • 2 comments

I have tried to add some small helper functions that allow one to merge the chains of a model with its generated quantities.

For more background on this, please see this Discourse thread.

I am very new to the internals of DynamicPPL.jl, so this is most likely not the optimal implementation, but I was encouraged on Slack to make this PR anyway. I hope this helps in some way.

ChristianMichelsen avatar May 24 '22 13:05 ChristianMichelsen

Thanks for your reply, @devmotion!

It seems my code might belong somewhere else. Do you have any suggestions? Turing.jl? MCMCChains.jl?

Regarding your suggestions, I have a couple of comments:

  • I thought small functions were good, both from a compiler standpoint and for readability. I have changed that now.
  • Would a name with a `compute_` prefix, for example, be better? Or what do you suggest?
  • I have changed this as per your suggestion.
  • I have no experience with testing in Julia, so I'll have to add tests sometime later.

My updated code is below, now as a single module:

module MergeChains

# NOTE(review): `Base` also exports `merge`, so after `using MergeChains` an
# unqualified `merge` in the caller's module is ambiguous; call it as
# `MergeChains.merge` or bring it in explicitly with `using MergeChains: merge`.
export merge


import Turing


"""
    get_generated_quantities(model::Turing.Model, chains::Turing.Chains)

Re-evaluate `model` on the samples stored in `chains` and return whatever the
model's `return` statement produced for each sample.

Only the `:parameters` section of `chains` is passed on to
`Turing.generated_quantities`; all other sections are dropped first. The result
is indexed as an (iterations × chains) matrix by the functions below.
"""
function get_generated_quantities(model::Turing.Model, chains::Turing.Chains)
    chains_params = Turing.MCMCChains.get_sections(chains, :parameters)
    return Turing.generated_quantities(model, chains_params)
end


# Flatten an array-valued quantity (in column-major order) so it fits into one
# row of the samples array; non-array values are passed through unchanged and
# broadcast as scalars.
_flatten(value::AbstractArray) = vec(value)
_flatten(value) = value


"""
    generated_quantities_to_chain(generated_quantities, chains, variable)

Build a `Turing.Chains` holding the samples of the single generated quantity
`variable`.

# Arguments
- `generated_quantities::AbstractMatrix`: (iterations × chains) matrix whose
  entries are indexable by `variable` (e.g. the `NamedTuple`s a model returns).
- `chains::Turing.Chains`: the original chains; only its `info` is reused.
- `variable::Union{Symbol,String}`: key of the quantity to extract.

# Returns
A `Turing.Chains` with one column named `variable` when the quantity has a
single element, or `K` columns named `variable[1]`, …, `variable[K]` otherwise.
Array-valued quantities of any dimensionality are flattened in column-major
order, so the bracketed index is the linear index. The element count `K` is
taken from the first sample and is assumed to be the same for all samples.
"""
function generated_quantities_to_chain(
    generated_quantities::AbstractMatrix,
    chains::Turing.Chains,
    variable::Union{Symbol,String},
)
    # Number of values (K) stored for this variable, read from the first sample.
    K = length(first(generated_quantities)[variable])
    n_samples, n_chains = size(generated_quantities)

    samples = zeros(n_samples, K, n_chains)
    for (c, column) in enumerate(eachcol(generated_quantities)),
        (i, quantities) in enumerate(column)
        # `_flatten` is needed for matrix-valued quantities: broadcasting a
        # 2-D value directly into a length-K row throws a DimensionMismatch.
        samples[i, :, c] .= _flatten(quantities[variable])
    end

    if K == 1
        chain_names = [Symbol(variable)]
    else
        chain_names = [Symbol("$variable[$i]") for i in 1:K]
    end

    return Turing.Chains(samples, chain_names, info = chains.info)
end


"""
    generated_quantities_to_chain(generated_quantities, chains, variables::Tuple)

Build one chain per entry of `variables` (see the single-variable method) and
concatenate them column-wise into a single `Turing.Chains`.

`variables` is expected to be non-empty; `merge` guards against the empty case.
"""
function generated_quantities_to_chain(
    generated_quantities::AbstractMatrix,
    chains::Turing.Chains,
    variables::Tuple,
)
    per_variable = map(variables) do variable
        generated_quantities_to_chain(generated_quantities, chains, variable)
    end
    return hcat(per_variable...)
end


"""
    merge_generated_chains(chains::Turing.Chains, generated_chains::Turing.Chains)

Append the columns of `generated_chains` to `chains`, first giving
`generated_chains` the same iteration range as `chains` so the two can be
concatenated with `hcat`.
"""
function merge_generated_chains(chains::Turing.Chains, generated_chains::Turing.Chains)
    return hcat(chains, Turing.setrange(generated_chains, range(chains)))
end


"""
    merge(model::Turing.Model, chains::Turing.Chains)

Return `chains` extended with one column per value that `model` returns, for
each sample in `chains`.

`chains` is returned unchanged when there is nothing to add:
- there are no generated quantities at all;
- the model returned `nothing` for every sample;
- the model returned a collection with no keys (e.g. an empty `NamedTuple`).

The keys of the first sample's return value (typically the field names of a
returned `NamedTuple`) become the column names.
"""
function merge(model::Turing.Model, chains::Turing.Chains)
    generated_quantities = get_generated_quantities(model, chains)

    if isempty(generated_quantities) || eltype(generated_quantities) === Nothing
        return chains
    end

    variables = generated_quantities |> first |> keys
    # With no keys there are no columns to add, and `hcat()` with no chains
    # would not produce a `Chains` to merge.
    isempty(variables) && return chains

    generated_chains =
        generated_quantities_to_chain(generated_quantities, chains, variables)

    return merge_generated_chains(chains, generated_chains)
end

end # module

ChristianMichelsen avatar Jun 27 '22 10:06 ChristianMichelsen

Codecov Report

Patch coverage has no change and project coverage change: -1.17 :warning:

Comparison is base (e6dd4ef) 76.40% compared to head (67e433a) 75.24%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #409      +/-   ##
==========================================
- Coverage   76.40%   75.24%   -1.17%     
==========================================
  Files          21       22       +1     
  Lines        2522     2561      +39     
==========================================
  Hits         1927     1927              
- Misses        595      634      +39     
Impacted Files Coverage Δ
src/merge_generated_quantities.jl 0.00% <0.00%> (ø)

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Do you have feedback about the report comment? Let us know in this issue.

codecov[bot] avatar Jul 04 '23 21:07 codecov[bot]

Closed in favour of https://github.com/TuringLang/DynamicPPL.jl/pull/594

yebai avatar May 05 '24 15:05 yebai