
BUG: Option to disable jittering in `pymc.sampling.jax.sample_numpyro_nuts`

Open · digicosmos86 opened this issue 1 year ago · 0 comments

Describe the issue:

In `pymc.sampling.jax.sample_numpyro_nuts`, jitter is always applied to the initial values through the function `_get_batched_jittered_initial_points()`:

https://github.com/pymc-devs/pymc/blob/904a0eaac216732bc358dd91680cd428d95704f0/pymc/sampling/jax.py#L662-L668

That function already accepts a `jitter` argument that can disable jittering, but `sample_numpyro_nuts` never passes it, so it always takes its default value of `True`. It would be very helpful to expose a `jitter` argument on `sample_numpyro_nuts` itself so users can turn jittering off during initialization. As things stand, anyone who wants precise control over where the chains start has no way to get it.
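To make the request concrete, here is a minimal NumPy sketch of what a jitter toggle changes. It assumes, as a simplification, that jittering means adding Uniform(-1, 1) noise to each chain's starting value in the transformed (unconstrained) space. The helper `batched_initial_points` and the variable names are hypothetical and are not PyMC's implementation.

```python
import numpy as np


def batched_initial_points(base_point, chains, jitter=True, seed=None):
    """Return one starting point per chain, optionally jittered.

    Hypothetical helper: adds Uniform(-1, 1) noise to every value when
    `jitter` is True (assumed here to approximate PyMC's default jitter).
    """
    rng = np.random.default_rng(seed)
    points = []
    for _ in range(chains):
        point = {}
        for name, value in base_point.items():
            value = np.asarray(value, dtype=float)
            if jitter:
                # Perturb the starting value independently for each chain.
                value = value + rng.uniform(-1.0, 1.0, size=value.shape)
            point[name] = value
        points.append(point)
    return points


base = {"mu": 0.0, "sigma_log__": np.zeros(3)}
fixed = batched_initial_points(base, chains=4, jitter=False)
jittered = batched_initial_points(base, chains=4, jitter=True, seed=0)
```

With `jitter=False`, every chain starts from exactly the same point, which is the kind of control this issue asks `sample_numpyro_nuts` to offer.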

Reproducible code example:

If `sample_numpyro_nuts` had a top-level `jitter` option, the internal call could look like this:


```python
init_params = _get_batched_jittered_initial_points(
    model=model,
    chains=chains,
    initvals=initvals,
    random_seed=random_seed,
    jitter=jitter,  # forward the proposed top-level option here
)
```

Error message:

No response

PyMC version information:

pymc=5.10.2

Context for the issue:

No response

Posted by digicosmos86 on Dec 14 '23, 18:12.