Using MCMCTempering when I'm overloading `step`

Open pdeffebach opened this issue 1 year ago • 1 comments

Continuing a discussion on Slack. I am a little confused about what I need to overload in order to get an existing custom Metropolis-Hastings algorithm to work with Parallel Tempering.

The core issue is that my problem is discrete. We are essentially trying to solve the Traveling Salesman problem, so you can think of our state as a vector of booleans, each entry marking an edge as upgraded or not upgraded. Our sampler proposes moves that add edges to or remove edges from the upgraded set. We are looking for the best set of edges to upgrade according to some objective function W by drawing from the distribution

exp(beta * W(x))

where x evolves according to a Markov process. Just to be clear: W is deterministic, it's just an objective function. It's not like we have some underlying distribution pi(x) and we want to make better draws from pi(x) by drawing from pi(x)^beta.

In other words, there is no characterization of the distribution we are sampling from without beta. For example, even to implement a vanilla Metropolis-Hastings algorithm to get the best x, we still need to choose some value of beta.
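
Written out, the target and the resulting acceptance rule look as follows. This assumes a symmetric proposal, q(x' | x) = q(x | x'), so that no Hastings correction appears; the symbols \pi_\beta, Z_\beta, q and \alpha are notation introduced here for illustration and are not taken from any library.

```latex
\pi_\beta(x) = \frac{\exp\bigl(\beta\, W(x)\bigr)}{Z_\beta},
\qquad
Z_\beta = \sum_{x} \exp\bigl(\beta\, W(x)\bigr),
\qquad
\log \alpha(x \to x') = \min\bigl\{0,\; \beta\,\bigl(W(x') - W(x)\bigr)\bigr\}.
```

The normalizing constant Z_\beta cancels in the ratio, so only differences of W scaled by beta are ever needed.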

As far as I can tell, this means we need a custom step function which incorporates the annealing parameter, defined below. (This could be wrong).

Step implementation
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::PolicyObjective,
    sampler::PolicySampler,
    transition_prev::PolicyTransition;
    kwargs...
)
    invtemp = sampler.invtemp

    # Generate a new proposal.
    candidate = propose(rng, sampler, model, transition_prev)

    # Calculate the log acceptance probability and the log density of the candidate.
    objval_candidate = obj(model, candidate)

    logα = invtemp * (objval_candidate - obj(model, transition_prev))

    # Decide whether to return the previous params or the new one.
    trans = if -Random.randexp(rng) < logα
        transition(sampler, model, candidate, objval_candidate, true)
    else
        params = transition_prev.params
        objval = transition_prev.obj
        PolicyTransition(params, objval, false)
    end

    return trans, trans
end

As you can see, we store invtemp, the annealing parameter, in our Sampler struct. I was expecting something along these lines in MCMCTempering implementations, but I don't see it. Instead we have

 TemperedSampler(sampler, inverse_temperatures)

and I don't understand how the inverse_temperatures get attached to individual samplers. In MCMCTempering.jl, I see that beta gets attached to the model, not the sampler.

Code from MCMCTempering
make_tempered_model(sampler, model, beta) = make_tempered_model(model, beta)
function make_tempered_model(model, beta)
    if !implements_logdensity(model)
        error("`make_tempered_model` is not implemented for $(typeof(model)); either implement explicitly, or implement the LogDensityProblems.jl interface for `model`")
    end

    return TemperedLogDensityProblem(model, beta)
end
function make_tempered_model(model::AbstractMCMC.LogDensityModel, beta)
    return AbstractMCMC.LogDensityModel(TemperedLogDensityProblem(model.logdensity, beta))
end
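
To make the role of beta concrete, here is a minimal, self-contained sketch of what tempering the model (rather than the sampler) amounts to. `ToyTempered`, `toy_objective`, and `log_accept` are hypothetical names used only for illustration and are not part of MCMCTempering.jl. The sketch assumes the tempered log density is simply beta times the wrapped objective, and that the proposal is symmetric.

```julia
# Hypothetical stand-in for a tempered model: wraps an objective and an inverse temperature.
struct ToyTempered{F,T<:Real}
    logdensity::F   # the untempered objective W
    beta::T         # inverse temperature attached to the model
end

# Evaluating the tempered model scales the wrapped objective by beta.
(m::ToyTempered)(x) = m.beta * m.logdensity(x)

# Toy objective on a Boolean edge vector (hypothetical): linear reward, quadratic cost.
toy_objective(x::AbstractVector{Bool}) = sum(x) - 0.5 * sum(x)^2

# Log acceptance ratio of a symmetric-proposal Metropolis step.
# Note that no invtemp appears here: it is already inside `model`.
log_accept(model, current, candidate) = model(candidate) - model(current)

current = [true, false, false, true]
candidate = [true, false, false, false]

beta = 0.25
tempered = ToyTempered(toy_objective, beta)

# Beta carried by the model...
via_model = log_accept(tempered, current, candidate)
# ...versus beta carried by the sampler, as in the step function shown earlier.
via_sampler = beta * (toy_objective(candidate) - toy_objective(current))
```

Under these assumptions the two quantities coincide, which suggests a step function written against the tempered model would not need its own `invtemp` field.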

So overall I think I have two questions:

  1. How do I implement `step` to take advantage of a sampler / model with an attached beta?
  2. If I define my objective function to comply with the LogDensityProblems interface, will everything just work? Is this true even when I need a non-standard (discrete) sampler?

A little confused on how to get MCMCTempering working.

pdeffebach, Mar 07 '25 21:03

I've successfully gotten it working (I think), but I had to commit type piracy on MCMCTempering.compute_logdensities, so I think I did not do it correctly.

Can you tell me whether the implementation below looks correct?

Implementation
MCMCTempering.getparams_and_logprob(t::AbstractPolicyTransition) = t.params, t.obj
function MCMCTempering.setparams_and_logprob!!(t::AbstractPolicyTransition, params, obj)
    return PolicyTransition(params, obj, false)
end

function MCMCTempering.compute_logdensities(
    model::MCMCTempering.TemperedLogDensityProblem,
    state,
    state_other,
)
    return (
        MCMCTempering.getlogprob(model.logdensity, state),                        # This we can just extract.
        MCMCTempering.logdensity(model.logdensity, MCMCTempering.getparams(model, state_other)) # While this we need to compute.
    )
end

function MCMCTempering.compute_logdensities(
    model::MCMCTempering.TemperedLogDensityProblem,
    model_other::MCMCTempering.TemperedLogDensityProblem,
    state,
    state_other,
)
    return MCMCTempering.compute_logdensities(model, state, state_other)
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::MCMCTempering.TemperedLogDensityProblem{<:TemperedPolicyObjective, L},
    sampler::PolicySampler;
    initial_params=nothing,
    kwargs...
) where {L}

    params = initial_params === nothing ? propose(rng, sampler, model.logdensity) : initial_params
    trans = transition(sampler, model.logdensity, params, false)
    return trans, trans
end

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::MCMCTempering.TemperedLogDensityProblem{<:TemperedPolicyObjective, L},
    sampler::PolicySampler,
    transition_prev::PolicyTransition;
    kwargs...
) where {L}
    # Generate a new proposal.
    candidate = propose(rng, sampler, model.logdensity, transition_prev)
    # Calculate the log acceptance probability and the log density of the candidate.
    objval_candidate = LogDensityProblems.logdensity(model, candidate)

    logα = (objval_candidate - LogDensityProblems.logdensity(model, transition_prev))

    # Decide whether to return the previous params or the new one.
    trans = if -Random.randexp(rng) < logα
        transition(sampler, model.logdensity, candidate, objval_candidate, true)
    else
        params = transition_prev.params
        objval = transition_prev.obj
        PolicyTransition(params, objval, false)
    end

    return trans, trans
end
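
For context on why the log densities of both states are needed in `compute_logdensities`: in standard parallel tempering, a swap between two chains is accepted based on each chain's target evaluated at both states. The sketch below is the textbook formula, not a description of MCMCTempering.jl's internals. `W`, `swap_logratio`, and `accept_swap` are hypothetical names, and it assumes chain i targets exp(beta_i * W(x)).

```julia
using Random

# Hypothetical objective on a Boolean state vector.
W(x::AbstractVector{Bool}) = float(sum(x))

# Log acceptance ratio for swapping state x_i (held at beta_i) with x_j (held at beta_j):
#   [beta_i*W(x_j) + beta_j*W(x_i)] - [beta_i*W(x_i) + beta_j*W(x_j)]
# which simplifies to the product below.
swap_logratio(beta_i, beta_j, x_i, x_j) = (beta_i - beta_j) * (W(x_j) - W(x_i))

# Accept with probability min(1, exp(logratio)).
# -randexp(rng) has the same distribution as log(U) with U uniform on (0, 1).
accept_swap(rng, logratio) = -Random.randexp(rng) < logratio

rng = Random.MersenneTwister(1)
x_cold = [true, false, false]   # state currently held by the cold chain (beta = 1.0)
x_hot = [true, true, true]      # state currently held by the hot chain (beta = 0.1)

logratio = swap_logratio(1.0, 0.1, x_cold, x_hot)
swapped = accept_swap(rng, logratio)
```

Here the hot chain holds the state with the higher objective, so the log ratio is positive and the swap is always accepted. Each term in the ratio is one chain's objective evaluated at the other chain's state, which is the quantity that has to be recomputed rather than read off a stored transition.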

pdeffebach, Mar 11 '25 20:03