pymc
BUG: Option to disable jittering in `pymc.sampling.jax.sample_numpyro_nuts`
Describe the issue:
In `pymc.sampling.jax.sample_numpyro_nuts`, jitter is always applied to the initial values through the function `_get_batched_jittered_initial_points()`:
https://github.com/pymc-devs/pymc/blob/904a0eaac216732bc358dd91680cd428d95704f0/pymc/sampling/jax.py#L662-L668
This function already accepts a `jitter` argument that allows jittering to be disabled. It would be helpful if `sample_numpyro_nuts` exposed a similar `jitter` argument, so that users could disable jittering during initialization. Because `jitter` is always set to `True`, users who want precise control over where each chain starts currently cannot get it.
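To make the request concrete, here is a minimal NumPy sketch of what "batched jittered initial points" means. It assumes, as we understand PyMC's implementation, that jitter adds independent Uniform(-1, 1) noise to each free variable on the transformed (unconstrained) scale. The function `batched_initial_points` and the example variable names are hypothetical illustrations, not PyMC API.

```python
from __future__ import annotations

import numpy as np


def batched_initial_points(
    initial_point: dict[str, np.ndarray],
    chains: int,
    jitter: bool = True,
    seed: int | None = None,
) -> dict[str, np.ndarray]:
    """Return one starting value per chain for every variable.

    Hypothetical sketch: assumes ``initial_point`` is already on the
    unconstrained (transformed) scale and models jitter as Uniform(-1, 1) noise.
    """
    rng = np.random.default_rng(seed)
    batched = {}
    for name, value in initial_point.items():
        value = np.asarray(value, dtype=float)
        # Repeat the same point for every chain: shape (chains, *value.shape).
        stacked = np.broadcast_to(value, (chains, *value.shape)).copy()
        if jitter:
            # Each chain receives its own independent perturbation.
            stacked += rng.uniform(-1.0, 1.0, size=stacked.shape)
        batched[name] = stacked
    return batched


# Example variables (hypothetical model with a scalar and a log-transformed vector).
point = {"mu": np.array(0.0), "sigma_log__": np.array([0.5, -0.2])}
jittered = batched_initial_points(point, chains=4, jitter=True, seed=1)
exact = batched_initial_points(point, chains=4, jitter=False)
```

With `jitter=False`, every chain starts exactly at the supplied point; with `jitter=True`, the chains start at nearby but distinct points. Only the first behavior gives the user full control, which is what this issue asks `sample_numpyro_nuts` to make reachable.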
Reproducible code example:
With a top-level `jitter` option, the call inside `sample_numpyro_nuts` could look like this:
```python
init_params = _get_batched_jittered_initial_points(
    model=model,
    chains=chains,
    initvals=initvals,
    random_seed=random_seed,
    jitter=jitter,  # add this option here
)
```
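Disabling jitter matters most when starting values are chosen deliberately, for example one value per chain. The sketch below uses a hypothetical helper, `stack_per_chain_initvals`, to show how either a single initvals dict or one dict per chain (the two forms `pm.sample` documents for `initvals`) could be turned into arrays with a leading chain dimension without any perturbation. It illustrates the idea under those assumptions and is not PyMC's implementation.

```python
from __future__ import annotations

from collections.abc import Mapping, Sequence

import numpy as np


def stack_per_chain_initvals(
    initvals: Mapping[str, float] | Sequence[Mapping[str, float]],
    chains: int,
) -> dict[str, np.ndarray]:
    """Stack user-supplied starting values into arrays of shape (chains, ...).

    Hypothetical helper: a single mapping is reused for every chain, while a
    sequence must provide exactly one mapping per chain. No jitter is applied,
    so every chain starts exactly where the user asked.
    """
    if isinstance(initvals, Mapping):
        per_chain = [initvals] * chains
    else:
        per_chain = list(initvals)
        if len(per_chain) != chains:
            raise ValueError(f"expected {chains} initval dicts, got {len(per_chain)}")
    names = per_chain[0].keys()
    # Leading axis indexes the chain, matching a batched init_params layout.
    return {
        name: np.stack([np.asarray(vals[name], dtype=float) for vals in per_chain])
        for name in names
    }


# Three chains, each deliberately started at a different value of "mu".
init = stack_per_chain_initvals([{"mu": -2.0}, {"mu": 0.0}, {"mu": 2.0}], chains=3)
# One shared starting value broadcast to four chains.
shared = stack_per_chain_initvals({"mu": 1.5}, chains=4)
```

Today, even carefully chosen per-chain values like these would be perturbed by the always-on jitter before sampling starts.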
Error message:
No response
PyMC version information:
pymc=5.10.2
Context for the issue:
No response