
Explicitly pass the temperature parameter from (adaptive) tempered SMC to the MCMC kernel

Open maxhinne opened this issue 2 years ago • 2 comments

Current behavior

Currently, the (adaptive) tempered SMC kernel samples (as desired) from lmbda * loglikelihood + logprior, where loglikelihood and logprior are log-densities provided by the user. The temperature lmbda is increased from 0 to 1 over the SMC iterations, so that the final iteration targets the posterior. This works as intended when combined with other Blackjax kernels that take a (posterior) logdensity as argument, such as RMH, HMC and NUTS.

However, other kernels, such as elliptical_slice or mgrad_gaussian, take only the loglikelihood as an argument, as here the prior is built into the model. The same typically applies to custom Gibbs kernels, especially if they describe a hierarchical model: we'd need to construct several target densities within the Gibbs kernel, some that require the temperature and others that do not. As a result, SMC cannot be usefully combined with these kernels. Technically it can, but the target distribution would not actually be tempered: the MCMC steps would always sample from the posterior, and the tempering would only be used when reweighting the SMC particles. This hampers exploration, the key benefit of SMC.

The relevant code is on lines 128 and 132 in https://github.com/blackjax-devs/blackjax/blob/main/blackjax/smc/tempered.py:

state = mcmc_init_fn(position, tempered_logposterior_fn)

and

new_state, info = mcmc_step_fn(
    rng_key, state, tempered_logposterior_fn, **mcmc_parameters
)

In both cases, the MCMC-within-SMC kernel is called with the tempered_logposterior_fn, in which lmbda has already been incorporated.
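To make the problem concrete, here is a minimal sketch in plain Python of how such a closure hides the temperature. It is not the Blackjax source: the 1D Gaussian prior and likelihood and the helper make_tempered_logposterior are assumptions made up for illustration.

```python
def logprior_fn(x):
    # Toy standard normal prior, up to an additive constant (assumed).
    return -0.5 * x**2


def loglikelihood_fn(x):
    # Toy Gaussian likelihood centred at 3, up to an additive constant (assumed).
    return -0.5 * (x - 3.0) ** 2


def make_tempered_logposterior(lmbda):
    # Hypothetical helper: lmbda is captured in the closure, so the MCMC
    # kernel only ever sees the combined function, never lmbda itself.
    def tempered_logposterior_fn(x):
        return logprior_fn(x) + lmbda * loglikelihood_fn(x)

    return tempered_logposterior_fn


# Kernels that take a full logdensity (RMH, HMC, NUTS) get the tempered target.
target_at_0 = make_tempered_logposterior(0.0)  # coincides with the prior
target_at_1 = make_tempered_logposterior(1.0)  # coincides with the posterior

# A kernel that is handed loglikelihood_fn directly evaluates nothing that
# depends on lmbda, so its moves always target the untempered posterior.
```

With this structure, a kernel that needs the likelihood and the prior separately has no way to recover lmbda from tempered_logposterior_fn.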

Desired behavior

States are NamedTuples and therefore immutable, but since users have to define their own mcmc_init_fn anyway, lmbda could be passed as an argument to mcmc_init_fn. That leaves it up to the user to decide whether to store the temperature in the MCMC state object. In our own use case, we'd take the temperature from the MCMC state and use it to compute something like:

tempered_loglikelihood = lambda state: state.lmbda*loglikelihood_fn(state)
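A slightly fuller sketch of this proposal, in plain Python, might look as follows. Everything here is hypothetical: the TemperedState class, the two-argument mcmc_init_fn signature and the toy likelihood are assumptions for illustration, not the current Blackjax API.

```python
from typing import NamedTuple


class TemperedState(NamedTuple):
    # Hypothetical user-defined MCMC state that also records the temperature.
    position: float
    lmbda: float


def loglikelihood_fn(position):
    # Toy Gaussian log-likelihood, up to an additive constant (assumed).
    return -0.5 * (position - 3.0) ** 2


def mcmc_init_fn(position, lmbda):
    # Proposed signature: the SMC kernel passes lmbda explicitly and the
    # user decides whether to keep it in the state.
    return TemperedState(position=position, lmbda=lmbda)


def tempered_loglikelihood(state):
    # The user's kernel reads the temperature back from the state.
    return state.lmbda * loglikelihood_fn(state.position)


state = mcmc_init_fn(1.0, 0.5)
value = tempered_loglikelihood(state)
```

Because the state is rebuilt by mcmc_init_fn at every tempering step, immutability of the NamedTuple is not an obstacle in this design.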

I hope this enhancement is possible; it would greatly ease our way of working with Blackjax! :-)

maxhinne • Feb 22 '23 16:02

I agree with the above. We have had several chats with @rlouf in the past on this specific point.

Food for thought:

The "only" way I see to do this is to abstract away the concept of the prior/loglikelihood into a form of mcmc_factory which would be user defined. It is, however, not clear how to do this in a clean way, given that the log-likelihood is needed to compute the importance weights at each tempering step.
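For context, the dependence on the log-likelihood comes from the standard tempered SMC reweighting, where moving from lmbda_old to lmbda_new multiplies each particle's weight by the likelihood raised to the power (lmbda_new - lmbda_old). A minimal plain-Python sketch, with a toy likelihood and helper names that are assumptions rather than Blackjax functions:

```python
import math


def loglikelihood_fn(x):
    # Toy Gaussian log-likelihood, up to an additive constant (assumed).
    return -0.5 * (x - 3.0) ** 2


def incremental_log_weights(particles, lmbda_old, lmbda_new):
    # Needs the untempered log-likelihood for every particle, independently
    # of whatever target the MCMC kernel is given.
    return [(lmbda_new - lmbda_old) * loglikelihood_fn(x) for x in particles]


def normalize(log_weights):
    # Subtract the maximum before exponentiating for numerical stability.
    shift = max(log_weights)
    unnormalized = [math.exp(lw - shift) for lw in log_weights]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]


log_w = incremental_log_weights([3.0, 1.0], lmbda_old=0.0, lmbda_new=0.5)
weights = normalize(log_w)
```

Any user-defined factory would therefore still have to expose the raw log-likelihood to the SMC step, which is what makes a clean abstraction harder.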

There is also the fact that the log-likelihood is not allowed to change in the definition of the elliptical slice sampler or the Gaussian samplers. Being able to combine both would likely require some refactoring on this end too.

AdrienCorenflos • Feb 23 '23 07:02

An additional point to note: at the moment, at least the Gaussian sampler is parametrised in a way that reduces the need for computation:

https://github.com/blackjax-devs/blackjax/blob/6b76746e0352c71c2489b12d38b8457ffb1ab13e/blackjax/mcmc/marginal_latent_gaussian.py#L46

Changing the tempering parameter from under it without updating the U_grad_x of the state would result in an invalid algorithm. In practice, it is easy to do, but we can't expect the user (even an expert one!) to know low-level implementation details to this extent. I'm sure there are other examples of this in the library, where changing the target is very non-obvious.
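The failure mode can be illustrated with a deliberately simplified sketch in plain Python. The CachedState class and its cached_grad field are hypothetical stand-ins for a state that stores a derived quantity such as U_grad_x; the real Blackjax state is more involved.

```python
from typing import NamedTuple


class CachedState(NamedTuple):
    # Hypothetical, simplified state: cached_grad stands in for a quantity
    # like U_grad_x that is computed once and reused by the next step.
    position: float
    cached_grad: float


def loglikelihood_grad(x):
    # Gradient of the toy log-likelihood -0.5 * (x - 3)^2 (assumed).
    return -(x - 3.0)


def init(position, lmbda):
    # The cache is computed under the temperature in force at init time.
    return CachedState(position, lmbda * loglikelihood_grad(position))


def refresh(state, lmbda):
    # What must happen when lmbda changes: recompute the cache.
    return init(state.position, lmbda)


state = init(1.0, 0.25)
stale_grad = state.cached_grad               # computed under lmbda = 0.25
fresh_grad = refresh(state, 0.5).cached_grad  # what lmbda = 0.5 requires
```

If the temperature moves from 0.25 to 0.5 and the old state is reused as-is, the kernel would step with stale_grad rather than fresh_grad, and nothing in the state's type signals the mismatch.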

AdrienCorenflos • Feb 23 '23 07:02