DynamicHMC.jl

Timing warmup and sampling

Status: open. Opened by sethaxen; 2 comments.

I would like to be able to separately time warmup and post-warmup sampling. A stretch goal would be the ability to time each individual stage of the warm-up. Is this possible with just API functions?

sethaxen, Feb 03 '23 14:02

The relevant API is only semi-exposed: it is not official, but it has been stable for a long time. See here. It is also partly documented; see mcmc_keep_warmup.

Here is an MWE:

using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

ℓ = StandardMultivariateNormal(5)    # test target from LogDensityTestSuite
∇ℓ = ADgradient(:ForwardDiff, ℓ)     # add gradients via ForwardDiff
rng = Random.default_rng()

wu = DynamicHMC.default_warmup_stages()   # tuple of warmup stages

# Turn the final warmup state of a run into an `initialization` for the next run:
# position q, kinetic energy (metric) κ, and stepsize ϵ.
function extract_initialization(state)
    (; Q, κ, ϵ) = state.final_warmup_state
    (; q = Q.q, κ, ϵ)
end

# Run one stage per call with 0 posterior draws, passing the adapted state along.
state1 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[1:1])
state2 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[2:2],
                                     initialization = extract_initialization(state1))
state3 = DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0; warmup_stages = wu[3:3],
                                     initialization = extract_initialization(state2))
# just keep doing this, and run the last stage with as many samples as you need
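The one-stage-per-call pattern in the MWE can be turned into a loop that addresses the stretch goal of timing each warmup stage individually. This is a minimal sketch, assuming DynamicHMC 3.x (where mcmc_keep_warmup returns a final_warmup_state with fields Q, κ and ϵ, as in the MWE) and Julia 1.7 or later; the helper name time_warmup_stages is hypothetical and not part of DynamicHMC.

```julia
using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

# Hypothetical helper (not part of DynamicHMC): run the warmup stages one at a
# time, feeding each stage's final state into the next, and record how long
# each stage takes (wall-clock seconds, as reported by `@timed`).
function time_warmup_stages(rng, ∇ℓ, stages; initialization = ())
    stage_times = Float64[]
    for stage in stages
        stats = @timed DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0;
                                                   warmup_stages = (stage,),
                                                   initialization)
        push!(stage_times, stats.time)
        # Carry the adapted position, metric and stepsize into the next stage.
        (; Q, κ, ϵ) = stats.value.final_warmup_state
        initialization = (; q = Q.q, κ, ϵ)
    end
    return stage_times, initialization
end

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.default_rng()
stages = DynamicHMC.default_warmup_stages()

stage_times, initialization = time_warmup_stages(rng, ∇ℓ, stages)
for (i, t) in enumerate(stage_times)
    println("warmup stage $i: $(round(t; digits = 3)) s")
end
```

The first time a given stage type runs, Julia also compiles the code involved, so timings from a first call include compilation; running the loop twice and keeping the second set of numbers gives a cleaner picture of the sampler itself.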

Please keep the issue open even if this answers your question: I would like to expose this part of the API, since I occasionally use it too.

tpapp, Feb 03 '23 16:02

Thanks! This seems to work well! Here is how I separately time the entire warm-up phase and the sampling phase:

using DynamicHMC, Random

# Convert a final warmup state into an `initialization` named tuple.
function extract_initialization(state)
    (; Q, κ, ϵ) = state
    return (; q=Q.q, κ, ϵ)
end

# Run the warmup stages one at a time (0 draws each), threading the adapted
# state through with `foldl`; returns the initialization to sample from.
function dhmc_warmup(
    rng::Random.AbstractRNG,
    ℓ;
    initialization=(),
    warmup_stages=DynamicHMC.default_warmup_stages(),
    kwargs...,
)
    initialization_final = foldl(warmup_stages; init=initialization) do init, stage
        result = DynamicHMC.mcmc_keep_warmup(
            rng, ℓ, 0; warmup_stages=(stage,), initialization=init, kwargs...
        )
        return extract_initialization(result.final_warmup_state)
    end
    return initialization_final
end

# Draw `ndraws` samples with no further warmup, starting from `initialization`.
function dhmc_sample(rng::Random.AbstractRNG, ℓ, ndraws; initialization, kwargs...)
    return DynamicHMC.mcmc_with_warmup(
        rng, ℓ, ndraws; warmup_stages=(), initialization, kwargs...
    )
end

Then I call each of these functions with @timed to time the two phases separately.
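When only the two totals (warmup vs. sampling) are needed, a possibly simpler variant is to run all warmup stages in a single mcmc_keep_warmup call with zero draws, then start mcmc_with_warmup from its final state with no warmup stages, wrapping each phase in @timed (which in Julia 1.5 and later returns a named tuple with value and time fields). This is a sketch assuming DynamicHMC 3.x, where mcmc_keep_warmup uses default_warmup_stages() when no stages are given; the variable names are illustrative.

```julia
using DynamicHMC, LogDensityTestSuite, ForwardDiff, LogDensityProblemsAD, Random

ℓ = StandardMultivariateNormal(5)
∇ℓ = ADgradient(:ForwardDiff, ℓ)
rng = Random.default_rng()

# Phase 1 (sketch, assuming DynamicHMC 3.x): all default warmup stages in one
# call, with zero posterior draws.
warmup = @timed DynamicHMC.mcmc_keep_warmup(rng, ∇ℓ, 0)
(; Q, κ, ϵ) = warmup.value.final_warmup_state

# Phase 2: sampling only, starting from the adapted position, metric and stepsize.
sampling = @timed mcmc_with_warmup(rng, ∇ℓ, 1000; warmup_stages = (),
                                   initialization = (; q = Q.q, κ, ϵ))

println("warmup:   $(round(warmup.time; digits = 3)) s")
println("sampling: $(round(sampling.time; digits = 3)) s")
```

This gives up per-stage timings in exchange for needing no helper functions; if per-stage numbers matter, the stage-by-stage fold in dhmc_warmup is the better fit.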

sethaxen, Feb 07 '23 20:02