
Expose the adaptation as a kernel

Open · rlouf opened this issue 3 years ago · 0 comments

The window adaptation is currently only available through a function that implements the adaptation loop and jit-compiles it. This is convenient and should be kept, but we should also expose the window adaptation as a kernel, and, for consistency, do the same for the slow and fast updates.
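In Stan-style window adaptation, which BlackJAX's window adaptation is modeled on, the "fast" update adapts only the step size, using the Nesterov dual-averaging scheme from Hoffman & Gelman's NUTS paper. Below is a minimal sketch of that update written as an `init`/`update` kernel pair, in plain Python for readability. The names `DualAveragingState`, `dual_averaging_init` and `dual_averaging_update` are hypothetical and are not BlackJAX API. The default constants (target acceptance 0.8, `gamma=0.05`, `t0=10`, `kappa=0.75`) are assumed to match Stan's defaults.

```python
import math
from typing import NamedTuple


class DualAveragingState(NamedTuple):
    log_step_size: float      # current log step size, used while sampling
    log_step_size_avg: float  # averaged log step size, used once the window ends
    h_bar: float              # running average of (target - acceptance rate)
    mu: float                 # value the log step size is shrunk towards
    step: int                 # number of updates performed so far


def dual_averaging_init(initial_step_size: float) -> DualAveragingState:
    # Shrink towards log(10 * initial_step_size), as in Hoffman & Gelman.
    return DualAveragingState(
        log_step_size=math.log(initial_step_size),
        log_step_size_avg=0.0,
        h_bar=0.0,
        mu=math.log(10.0 * initial_step_size),
        step=0,
    )


def dual_averaging_update(
    state: DualAveragingState,
    acceptance_rate: float,
    target: float = 0.8,
    gamma: float = 0.05,
    t0: float = 10.0,
    kappa: float = 0.75,
) -> DualAveragingState:
    t = state.step + 1
    eta = 1.0 / (t + t0)
    # Running average of how far the acceptance rate is from the target.
    h_bar = (1.0 - eta) * state.h_bar + eta * (target - acceptance_rate)
    # Too low an acceptance rate (h_bar > 0) pushes the step size down.
    log_step_size = state.mu - math.sqrt(t) / gamma * h_bar
    # Polynomially decaying weights for the averaged iterate.
    weight = t ** (-kappa)
    log_step_size_avg = (
        weight * log_step_size + (1.0 - weight) * state.log_step_size_avg
    )
    return DualAveragingState(log_step_size, log_step_size_avg, h_bar, state.mu, t)
```

Because the state is an immutable tuple and `update` is a pure function, the same pair could be written with `jax.numpy` and driven from a user-written loop or from a `jax.lax.scan`, which is the flexibility this issue asks for.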

Currently, the interface looks like this:

import blackjax

adapt = blackjax.window_adaptation(blackjax.nuts, logprob_fn, num_warmup_steps)
state, kernel, info = adapt.run(rng_key, position)

We want to keep this interface, but also make it possible to update the chain state and the warmup state separately:

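import blackjax

# Proposed interface: initialize the chain state and the adaptation state independently.
state = blackjax.nuts.init(position, logprob_fn)
warmup_state = blackjax.window_adaptation.init(num_warmup_steps)

# One chain step, using the parameters adapted so far.
new_state, info = blackjax.nuts.step(
    rng_key,
    state,
    logprob_fn,
    warmup_state.step_size,
    warmup_state.inverse_mass_matrix,
)
# One adaptation step, fed with the new chain state and its info
# (e.g. the acceptance rate).
new_warmup_state, warmup_info = blackjax.window_adaptation.update(
    warmup_state, new_state, info
)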
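The `warmup_state.inverse_mass_matrix` above has to be estimated from the chain itself. In Stan-style adaptation this is done during the "slow" windows with a running (Welford) variance estimate of the positions. The sketch below shows that estimator as an `init`/`update` pair plus a function that turns it into a diagonal inverse mass matrix. It is plain Python over lists, and every name in it is hypothetical, not BlackJAX API. In the proposed design the `position` argument would come from something like `new_state.position`. The shrinkage constants follow Stan's metric regularization and are an assumption about what BlackJAX would use.

```python
from typing import List, NamedTuple


class WelfordState(NamedTuple):
    count: int          # number of positions seen so far
    mean: List[float]   # running mean of each coordinate
    m2: List[float]     # running sum of squared deviations from the mean


def welford_init(dim: int) -> WelfordState:
    return WelfordState(0, [0.0] * dim, [0.0] * dim)


def welford_update(state: WelfordState, position: List[float]) -> WelfordState:
    count = state.count + 1
    mean, m2 = [], []
    for mu, s, x in zip(state.mean, state.m2, position):
        delta = x - mu                     # deviation from the old mean
        new_mu = mu + delta / count
        mean.append(new_mu)
        m2.append(s + delta * (x - new_mu))  # numerically stable M2 update
    return WelfordState(count, mean, m2)


def welford_inverse_mass_matrix(
    state: WelfordState, regularize: bool = True
) -> List[float]:
    if state.count < 2:
        raise ValueError("need at least two positions to estimate a variance")
    # Sample variance of each coordinate = diagonal of the inverse mass matrix.
    variance = [s / (state.count - 1) for s in state.m2]
    if not regularize:
        return variance
    # Shrink towards 1e-3 * identity, which matters most for short windows.
    n = state.count
    return [(n / (n + 5.0)) * v + 1e-3 * (5.0 / (n + 5.0)) for v in variance]
```

At the end of each slow window, the estimate would be read off with `welford_inverse_mass_matrix` and the estimator reset with `welford_init`, so the next window starts from fresh statistics.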

rlouf, Feb 08 '22 09:02