blackjax
Expose the adaptation as a kernel
The window adaptation is currently only available through a function that implements the adaptation loop and JIT-compiles it. This is convenient and should be kept, but we should also expose the window adaptation as a kernel (and, for consistency, the slow and fast updates as well).
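In Stan-style window adaptation, the slow windows estimate a diagonal inverse mass matrix from the variance of the warmup samples, and that variance can be accumulated online with Welford's algorithm. The sketch below shows what a "slow update" kernel could look like as an `init`/`update` pair over an immutable state. It is plain Python, the names (`WelfordState`, `welford_init`, `welford_update`, `welford_final`) are hypothetical and not blackjax's API, and it omits any regularization of the final estimate.

```python
from typing import List, NamedTuple


class WelfordState(NamedTuple):
    """Hypothetical running statistics for a diagonal variance estimate."""
    count: int
    mean: List[float]
    m2: List[float]  # running sum of squared deviations from the mean


def welford_init(dim: int) -> WelfordState:
    # No samples seen yet.
    return WelfordState(count=0, mean=[0.0] * dim, m2=[0.0] * dim)


def welford_update(state: WelfordState, position: List[float]) -> WelfordState:
    # Fold one sample into the statistics; the input state is not mutated.
    count = state.count + 1
    mean, m2 = [], []
    for mu, s, x in zip(state.mean, state.m2, position):
        delta = x - mu
        new_mu = mu + delta / count
        mean.append(new_mu)
        m2.append(s + delta * (x - new_mu))
    return WelfordState(count, mean, m2)


def welford_final(state: WelfordState) -> List[float]:
    # Per-dimension sample variance, usable as a diagonal inverse mass matrix.
    if state.count < 2:
        raise ValueError("need at least two samples to estimate a variance")
    return [s / (state.count - 1) for s in state.m2]
```

Because each call returns a new state, a caller can interleave these updates with sampler steps however it likes, which is the flexibility this issue is asking for.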
Currently, the API looks like this:
```python
import blackjax

# Build the adaptation loop for NUTS, then run it (the loop is JIT-compiled).
adapt = blackjax.window_adaptation(blackjax.nuts, logprob_fn, num_warmup_steps)
state, kernel, info = adapt.run(rng_key, position)
```
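Exposing a kernel does not require dropping the convenience function: a loop-based `run` could be rebuilt as a thin driver over a step kernel and an adaptation kernel. The sketch below is plain Python with hypothetical names (`run_warmup`, `step_fn`, `adapt_update`) that are not blackjax functions. Random keys are omitted for brevity; in JAX the loop would typically be written with `jax.lax.scan` or `jax.lax.fori_loop` so that it can be JIT-compiled, with a key split per step.

```python
from typing import Any, Callable, Tuple

# step_fn(chain_state, adapt_state) -> (new_chain_state, info)
StepFn = Callable[[Any, Any], Tuple[Any, Any]]
# adapt_update(adapt_state, chain_state, info) -> new_adapt_state
AdaptUpdateFn = Callable[[Any, Any, Any], Any]


def run_warmup(chain_state, adapt_state, step_fn: StepFn,
               adapt_update: AdaptUpdateFn, num_steps: int):
    """Hypothetical loop-based warmup built from two separate kernels."""
    for _ in range(num_steps):
        # Move the chain using the current adapted parameters.
        chain_state, info = step_fn(chain_state, adapt_state)
        # Feed the transition back into the adaptation kernel.
        adapt_state = adapt_update(adapt_state, chain_state, info)
    return chain_state, adapt_state
```

For example, with a toy step `lambda s, a: (s + a, {})` and a toy update `lambda a, s, info: 2 * a`, `run_warmup(0, 1, ...)` over three steps returns `(7, 8)`.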
We wish to keep this design, but we would also like to allow the chain state and the warmup state to be updated separately:
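```python
# Proposed API: the chain state and the warmup state are updated separately.
state = blackjax.nuts.init(position, logprob_fn)
warmup_state = blackjax.window_adaptation.init(num_warmup_steps)

# One NUTS transition using the current adapted parameters.
new_state, info = blackjax.nuts.step(
    rng_key,
    state,
    logprob_fn,
    warmup_state.step_size,
    warmup_state.inverse_mass_matrix,
)

# Update the warmup state with the result of that transition.
new_warmup_state, warmup_info = blackjax.window_adaptation.update(
    warmup_state, new_state, info
)
```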
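In the proposal, `window_adaptation.update` has to produce a new `step_size` after each transition. A common way to do this is the dual averaging scheme described in the NUTS paper (Hoffman and Gelman). The sketch below writes it as an `init`/`update` pair in plain Python. The names (`DualAveragingState`, `da_init`, `da_update`, `da_final`) are hypothetical, not blackjax's API, and the default constants are assumed to match Stan's (target acceptance 0.8, `gamma=0.05`, `t0=10`, `kappa=0.75`).

```python
import math
from typing import NamedTuple


class DualAveragingState(NamedTuple):
    """Hypothetical dual averaging state for step size adaptation."""
    step: int                 # number of updates so far (t)
    log_step_size: float      # log step size for the next transition
    log_step_size_avg: float  # averaged log step size, used after warmup
    h_bar: float              # running average of (target - acceptance)
    mu: float                 # shrinkage point, log(10 * initial step size)


def da_init(initial_step_size: float) -> DualAveragingState:
    log_eps = math.log(initial_step_size)
    return DualAveragingState(0, log_eps, 0.0, 0.0, math.log(10.0 * initial_step_size))


def da_update(state: DualAveragingState, acceptance_rate: float, target: float = 0.8,
              gamma: float = 0.05, t0: float = 10.0, kappa: float = 0.75) -> DualAveragingState:
    t = state.step + 1
    eta = 1.0 / (t + t0)
    # Acceptance above target pushes h_bar down, which increases the step size.
    h_bar = (1.0 - eta) * state.h_bar + eta * (target - acceptance_rate)
    log_step_size = state.mu - math.sqrt(t) / gamma * h_bar
    # Weighted average of the iterates; later iterates get decreasing weight.
    weight = t ** (-kappa)
    log_avg = weight * log_step_size + (1.0 - weight) * state.log_step_size_avg
    return DualAveragingState(t, log_step_size, log_avg, h_bar, state.mu)


def da_final(state: DualAveragingState) -> float:
    # Step size to use once warmup is over.
    return math.exp(state.log_step_size_avg)
```

During warmup the sampler would use `math.exp(state.log_step_size)` for the next transition; `da_final` gives the step size kept afterwards.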