pymc

Use the same API for defining internal and external NUTS kwargs

Open · ricardoV94 opened this issue 2 years ago · 2 comments

Description

A user on Discourse reported:

How can I set the maximum tree depth for the NUTS method from the numpyro library? The way described in the test file test_mcmc_external.py doesn’t work:

import pymc as pm
import numpy as np

with pm.Model():
    a = pm.Normal("a")
    idata = pm.sample(
        nuts_sampler="numpyro",
        target_accept=0.99,
        nuts={"max_treedepth": 1},
        random_seed=1410,
    )

print(np.max(idata.sample_stats.tree_depth))
# <xarray.DataArray 'tree_depth' ()>
# array(4)

and specifying something via the nuts_kwargs argument throws ValueError: Unused step method arguments: {'nuts_kwargs'}.

I don't know if nuts should be converted to nuts_kwargs, but even if a user were to pass nuts_kwargs to sample, those wouldn't make it to the sample_numpyro_nuts function because we drop arbitrary kwargs passed here:

https://github.com/pymc-devs/pymc/blob/261862d778910a09c5b61edcc66958519a86815e/pymc/sampling/mcmc.py#L252
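To make the failure mode concrete, here is a minimal pure-Python sketch of the pattern described above: a dispatcher that accepts arbitrary keyword arguments but never passes them on, next to one that forwards them. The names `dispatch_dropping`, `dispatch_forwarding` and `fake_external_sampler` are hypothetical stand-ins, not PyMC functions, and the sketch does not reproduce the code at the linked line.

```python
from typing import Any, Dict


def fake_external_sampler(**sampler_kwargs: Any) -> Dict[str, Any]:
    """Hypothetical external sampler that just echoes the options it received."""
    return sampler_kwargs


def dispatch_dropping(nuts_sampler: str = "numpyro", **kwargs: Any) -> Dict[str, Any]:
    # **kwargs is accepted (so the call itself does not fail) but never
    # forwarded: whatever the caller put there is silently discarded.
    return fake_external_sampler()


def dispatch_forwarding(nuts_sampler: str = "numpyro", **kwargs: Any) -> Dict[str, Any]:
    # Forwarding the extra keywords lets them reach the external sampler.
    return fake_external_sampler(**kwargs)


print(dispatch_dropping(nuts_kwargs={"max_tree_depth": 1}))    # {}
print(dispatch_forwarding(nuts_kwargs={"max_tree_depth": 1}))  # {'nuts_kwargs': {'max_tree_depth': 1}}
```

The dropping variant is the harmful one from a user's point of view: nothing fails, yet the requested setting has no effect, which matches the unchanged tree depth in the report.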

ricardoV94 · Jun 07 '23 07:06

I could have sworn the NUTS arguments dict was called nuts_sampler_kwargs.

Yes, it is:

    idata_kwargs: Optional[Dict],
    nuts_sampler_kwargs: Optional[Dict],
    **kwargs,
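A minimal sketch of the two-hop route this signature implies, assuming that `nuts_sampler_kwargs` is unpacked into the external sampler function and that this function in turn unpacks a nested `nuts_kwargs` dict into the NumPyro NUTS kernel, whose tree-depth option is spelled `max_tree_depth` rather than PyMC's `max_treedepth`. `FakeNumpyroNUTS`, `fake_sample_numpyro_nuts` and `fake_sample` are hypothetical mocks, so no PyMC or NumPyro install is needed.

```python
from typing import Any, Dict, Optional


class FakeNumpyroNUTS:
    """Hypothetical stand-in for the NumPyro NUTS kernel (option spelled max_tree_depth)."""

    def __init__(self, target_accept_prob: float = 0.8, max_tree_depth: int = 10) -> None:
        self.target_accept_prob = target_accept_prob
        self.max_tree_depth = max_tree_depth


def fake_sample_numpyro_nuts(
    target_accept: float = 0.8, nuts_kwargs: Optional[Dict[str, Any]] = None
) -> FakeNumpyroNUTS:
    # Second hop: the nested nuts_kwargs dict is unpacked into the kernel constructor.
    return FakeNumpyroNUTS(target_accept_prob=target_accept, **(nuts_kwargs or {}))


def fake_sample(
    target_accept: float = 0.8, nuts_sampler_kwargs: Optional[Dict[str, Any]] = None
) -> FakeNumpyroNUTS:
    # First hop: nuts_sampler_kwargs is unpacked into the external sampler function.
    return fake_sample_numpyro_nuts(target_accept=target_accept, **(nuts_sampler_kwargs or {}))


kernel = fake_sample(
    target_accept=0.99,
    nuts_sampler_kwargs={"nuts_kwargs": {"max_tree_depth": 1}},
)
print(kernel.max_tree_depth)  # 1
```

Under those assumptions, the reported call would presumably become something like `pm.sample(nuts_sampler="numpyro", nuts_sampler_kwargs={"nuts_kwargs": {"max_tree_depth": 1}})`; the exact nesting depends on the `sample_numpyro_nuts` signature of the installed PyMC version. The sketch also shows how easy the two spellings are to mix up: putting `{"max_treedepth": 1}` in the inner dict would fail with a `TypeError` at the kernel constructor.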

fonnesbeck · Jun 09 '23 19:06

Shouldn't we use the same API as the one for passing kwargs to PyMC's own NUTS sampler?

pm.sample(..., {"nuts": ...})
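One possible shape for such a unified API is sketched below, assuming that `pm.sample` would accept a single PyMC-style `nuts` dict for every backend and rename the options per backend. The table `_NUTS_KWARG_NAMES` and the function `translate_nuts_kwargs` are hypothetical; only the tree-depth option is filled in, using PyMC's `max_treedepth` and NumPyro's `max_tree_depth` spellings.

```python
from typing import Any, Dict, Mapping, Optional

# Hypothetical per-backend renaming table: PyMC spelling -> backend spelling.
# Only the tree-depth option is shown; a real table would need every option.
_NUTS_KWARG_NAMES: Dict[str, Dict[str, str]] = {
    "pymc": {"max_treedepth": "max_treedepth"},
    "numpyro": {"max_treedepth": "max_tree_depth"},
}


def translate_nuts_kwargs(nuts_sampler: str, nuts: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
    """Map a PyMC-style ``nuts`` dict onto the chosen backend's keyword names.

    An unknown backend raises KeyError; unknown option names raise ValueError
    instead of being silently dropped.
    """
    table = _NUTS_KWARG_NAMES[nuts_sampler]
    options = dict(nuts or {})
    unknown = set(options) - set(table)
    if unknown:
        raise ValueError(f"Unsupported NUTS arguments for {nuts_sampler!r}: {sorted(unknown)}")
    return {table[key]: value for key, value in options.items()}


print(translate_nuts_kwargs("pymc", {"max_treedepth": 1}))     # {'max_treedepth': 1}
print(translate_nuts_kwargs("numpyro", {"max_treedepth": 1}))  # {'max_tree_depth': 1}
```

Raising on unknown keys, rather than discarding them, is the deliberate design choice here: it would turn the silent no-op reported in this issue into an explicit error.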

ricardoV94 · Feb 06 '24 10:02