
Testing the interface with simulation-based calibration

Open · sethaxen opened this issue on Aug 03 '21 · 3 comments

At @torfjelde's suggestion, I am testing the AbstractPPL interface to see what it is missing. The test case here is simulation-based calibration (SBC). Working from the interface spec, I arrived at the following implementation:

using AbstractPPL, AbstractMCMC

# returns a collection of traces whose variable names are those produced by
# the generative model, minus the ones in data_vars; the corresponding values
# are the ranks of the prior draws among the posterior draws
function calibrate(rng, model, sampler, data_vars; nreps=1_000, ndraws=100)
    joint_model = AbstractPPL.decondition(model) # decondition just in case
    ranks = map(1:nreps) do _
        step_rank(rng, joint_model, sampler, ndraws, data_vars)
    end
    return ranks
end
function step_rank(rng, joint_model, sampler, ndraws, data_vars)
    θ̃_ỹ, _ = AbstractMCMC.step(rng, joint_model) # NOTE: method does not exist
    ỹ, θ̃ = split_trace(θ̃_ỹ, data_vars)
    posterior_model = AbstractPPL.condition(joint_model, ỹ)
    θ = AbstractMCMC.sample(rng, posterior_model, sampler, ndraws)
    return rank_draw_in_sample(θ̃, θ)
end
function split_trace(draw, vars)
    # split draw into the part whose variable names match vars and the part
    # whose names do not, handling indices in names correctly
    # e.g. if draw=@T(x=10, y[1]=10, y[2]=5) and vars=(:y,),
    # then this returns @T(y[1]=10, y[2]=5), @T(x=10,)
end
function rank_draw_in_sample(draw, sample)
    # compute the element-wise rank of each variable's value in draw among
    # the corresponding values in sample.
    # i.e. if draw=@T(x[1]=10) and sample=[@T(x[1]=40), @T(x[1]=10), @T(x[1]=0)],
    # then this returns @T(x[1]=3)
end
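
For concreteness, here is a minimal sketch of the two helpers, assuming a trace behaves like a Dict mapping AbstractPPL.VarName keys to scalar values (so @T(x=10, y[1]=10) would correspond to Dict(@varname(x) => 10, @varname(y[1]) => 10)) and that sample returns a vector of such traces; a first-class trace type with these operations is exactly what the interface still needs to provide:

using AbstractPPL: VarName, getsym

function split_trace(draw::AbstractDict{<:VarName}, vars)
    # keep the pairs whose symbol part (e.g. :y for y[1]) is in vars...
    matching = filter(((vn, _),) -> getsym(vn) in vars, draw)
    # ...and separately those whose symbol part is not
    rest = filter(((vn, _),) -> getsym(vn) ∉ vars, draw)
    return matching, rest
end

function rank_draw_in_sample(draw::AbstractDict{<:VarName}, sample)
    # reading the example above as rank = 1 + number of sampled values less
    # than or equal to the drawn value: 10 in [40, 10, 0] gives 1 + 2 = 3
    return Dict(vn => 1 + count(t -> t[vn] <= v, sample) for (vn, v) in draw)
end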

Here is an example of how one might use this with Turing:

using Turing, Random
@model function model(y)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1), 0, Inf)
    y ~ Normal(μ, σ)
end
rng = MersenneTwister(42)
calibrate(rng, model(1.5), NUTS(), (:y,))
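
Once the missing step method exists, SBC says that, for a well-calibrated model and sampler, each variable's ranks should be approximately uniform on 1:(ndraws + 1). A quick check, assuming the Dict-of-VarName trace sketch above:

using AbstractPPL: @varname

ranks = calibrate(rng, model(1.5), NUTS(), (:y,); nreps=200, ndraws=100)
μ_ranks = [r[@varname(μ)] for r in ranks]
# a roughly flat histogram of μ_ranks over 1:101 indicates good calibration,
# e.g. histogram(μ_ranks; bins=10) with StatsPlots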

What we are missing (so far):

  • A method like AbstractMCMC.step(rng, joint_model) to draw exact samples from the joint prior and prior-predictive distribution (a hand-rolled stand-in for the example model is sketched after this list).
  • Functionality to manipulate traces, e.g. splitting a trace into two traces based on variable names.
  • Functionality to map over the variable names and values of a trace, constructing a new trace with different values.
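
To make the first bullet concrete, here is what such a step method would have to produce for the Turing model above, hand-rolled with Distributions; prior_step is a hypothetical placeholder name, and a real implementation would derive the draws from the model itself rather than hard-coding them:

using Distributions, Random
using AbstractPPL: @varname

function prior_step(rng::Random.AbstractRNG)
    # ancestral sampling: draw each variable from its prior in model order
    μ = rand(rng, Normal(0, 1))
    σ = rand(rng, truncated(Normal(0, 1), 0, Inf))
    y = rand(rng, Normal(μ, σ))  # prior-predictive draw of the observed variable
    return Dict(@varname(μ) => μ, @varname(σ) => σ, @varname(y) => y)
end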

sethaxen · Aug 03 '21 11:08

Additionally (this is more of an AbstractMCMC comment): if sample could take an array of models rather than a single model, then we could also use the MCMCThreads, MCMCDistributed, and MCMCSerial parallelization options. A rough threaded stand-in is sketched below.
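
Until then, one can loop over the models manually; sample_many is a hypothetical helper name, and seeding one RNG per model is just a simple way to keep the threads independent:

using AbstractMCMC, Random

function sample_many(models, sampler, ndraws; seed=0)
    chains = Vector{Any}(undef, length(models))
    Threads.@threads for i in eachindex(models)
        # one independently seeded RNG per model, so threads share no RNG state
        rng = MersenneTwister(seed + i)
        chains[i] = AbstractMCMC.sample(rng, models[i], sampler, ndraws)
    end
    return chains
end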

sethaxen · Aug 03 '21 11:08

Thanks, @sethaxen - these are really helpful! These features are also useful for simulation-based inference algorithms, e.g. particle-based sampling algorithms (see https://github.com/TuringLang/AdvancedPS.jl).

yebai · Aug 03 '21 16:08

Yes, actually, diagnosing issues with SBI for posterior inference is one of the use cases I have for SBC.

sethaxen · Aug 05 '21 10:08