Use same API for defining internal and external Nuts kwargs
Description
A user on Discourse reported:
How can I set the maximum tree depth for the NUTS method from the numpyro library? The way described in the test file test_mcmc_external.py doesn’t work:
    import pymc as pm
    import numpy as np

    with pm.Model():
        a = pm.Normal("a")
        idata = pm.sample(nuts_sampler = "numpyro", target_accept = 0.99, nuts = {"max_treedepth": 1}, random_seed = 1410)

    print(np.max(idata.sample_stats.tree_depth))
    # <xarray.DataArray 'tree_depth' ()>
    # array(4)

and specifying something via the nuts_kwargs argument throws ValueError: Unused step method arguments: {'nuts_kwargs'}.
I don't know if nuts should be converted to nuts_kwargs, but even if a user were to pass nuts_kwargs to sample, those wouldn't make it to the sample_numpyro_nuts function because we drop arbitrary kwargs passed here:
https://github.com/pymc-devs/pymc/blob/261862d778910a09c5b61edcc66958519a86815e/pymc/sampling/mcmc.py#L252
I could have sworn the NUTS arguments dict was called nuts_sampler_kwargs.
Yes, it is:
idata_kwargs: Optional[Dict],
nuts_sampler_kwargs: Optional[Dict],
**kwargs,
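To make the mismatch concrete, here is a toy sketch of the dispatch behaviour discussed above. It is not the actual PyMC code: toy_sample, toy_external_nuts, the "external" sampler name, and the default depth of 10 are hypothetical stand-ins chosen for illustration. The only thing it assumes from this thread is that options passed through **kwargs (such as nuts={...}) never reach the external sampler, while nuts_sampler_kwargs does.

```python
from typing import Any, Dict, Optional


def toy_external_nuts(max_tree_depth: int = 10, **other: Any) -> Dict[str, Any]:
    """Hypothetical stand-in for an external NUTS backend."""
    return {"backend": "external", "max_tree_depth": max_tree_depth}


def toy_sample(
    nuts_sampler: str = "pymc",
    nuts_sampler_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Toy dispatcher: **kwargs configure the internal step method only."""
    if nuts_sampler != "pymc":
        # External branch: only nuts_sampler_kwargs is forwarded.
        # Step-method kwargs such as nuts={...} are silently dropped here.
        return toy_external_nuts(**(nuts_sampler_kwargs or {}))
    # Internal branch: read the step-method options from **kwargs.
    nuts_options = kwargs.get("nuts", {})
    return {"backend": "pymc", "max_tree_depth": nuts_options.get("max_treedepth", 10)}


# Passed as a step-method kwarg: ignored by the external branch.
ignored = toy_sample(nuts_sampler="external", nuts={"max_treedepth": 1})
# Passed through nuts_sampler_kwargs: reaches the external backend.
applied = toy_sample(nuts_sampler="external", nuts_sampler_kwargs={"max_tree_depth": 1})
# Same option for the internal sampler goes through a different argument.
internal = toy_sample(nuts={"max_treedepth": 1})

print(ignored["max_tree_depth"], applied["max_tree_depth"], internal["max_tree_depth"])
```

In this sketch the same setting has to travel through two different arguments depending on the backend, which is the inconsistency this issue is about.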
Shouldn't we use the same API for passing kwargs to the PyMC nuts?
pm.sample(..., {"nuts": ...})