arviz
`arviz.from_numpyro` seems to ignore `log_likelihood` argument
Describe the bug
The `arviz.from_numpyro` function has a `log_likelihood` argument which seems to let the user provide a dictionary(?) of likelihoods for posterior samples. However, it appears to be ignored entirely: the likelihoods are computed from the provided model instead. (I may be using it wrongly, as there appears to be no documentation of it; if so, please point that out to me.)
To Reproduce
Consider the following code snippet:
import numpyro.distributions as dists
import numpyro
import jax
import jax.numpy as jnp
import numpy as np
import arviz as az

def model(xs, ys):
    mu = numpyro.sample("mu", dists.Normal(0., 1.))
    with numpyro.plate("xs_obs", len(xs) if xs is not None else 1):
        xs_dist = dists.Normal(mu, 1.)
        xs = numpyro.sample("xs", xs_dist, obs=xs)
    with numpyro.plate("ys_obs", len(ys) if ys is not None else 1):
        ys_dist = dists.Normal(mu, 4.)
        ys = numpyro.sample("ys", ys_dist, obs=ys)
    ll = jnp.concatenate([xs_dist.log_prob(xs), ys_dist.log_prob(ys)])
    numpyro.deterministic("ll", ll)

xs = np.random.randn(1000) + 5.
ys = np.random.randn(1000) * 4. + 5.

nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_chains=4, num_warmup=100, num_samples=2000)
mcmc.run(jax.random.PRNGKey(782385), xs, ys)

samples = mcmc.get_samples(group_by_chain=True)
ll = { 'xs_ys': samples['ll'] }
idata = az.from_numpyro(mcmc, log_likelihood=ll)
print(idata['log_likelihood'])
assert list(idata['log_likelihood'].data_vars) == ["xs_ys"]
This prints
<xarray.Dataset>
Dimensions:   (chain: 4, draw: 2000, xs_dim_0: 1000, ys_dim_0: 1000)
Coordinates:
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 ... 1993 1994 1995 1996 1997 1998 1999
  * xs_dim_0  (xs_dim_0) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * ys_dim_0  (ys_dim_0) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    xs        (chain, draw, xs_dim_0) float32 -2.074 -1.25 ... -1.283 -1.435
    ys        (chain, draw, ys_dim_0) float32 -4.216 -2.825 ... -5.277 -2.749
Attributes:
    created_at:                 2023-01-21T14:29:16.129290
    arviz_version:              0.14.0
    inference_library:          numpyro
    inference_library_version:  0.10.1
Expected behavior
idata.log_likelihood should have a single data variable xs_ys with the contents of samples['ll'], i.e., arviz.from_numpyro should use the values provided via the log_likelihood instead of re-computing the log-likelihoods from the model.
Additional context
arviz==0.14.0
numpyro==0.10.1
jax==0.4.1
jaxlib==0.4.1
The `log_likelihood` parameter can only be a boolean (`True` or `False`), or `None` if you want the value in `rcParams` to be used. If you want custom values in the log likelihood group that differ from what the model defines, you should either modify them after calling `from_numpyro`, or use `log_likelihood=False` and then `InferenceData.add_groups` to create the group later on (as done, for example, in https://python.arviz.org/en/stable/user_guide/numpyro_refitting_xr_lik.html). Hope it helps.
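For reference, the second option could look roughly like the sketch below. It is only an illustration: `attach_log_likelihood` and its `obs_dim` argument are hypothetical names, not part of ArviZ, and it assumes the custom values are already shaped `(chain, draw, n_observations)`, as `mcmc.get_samples(group_by_chain=True)` returns them in the snippet above.

```python
def attach_log_likelihood(idata, values, name, obs_dim=None):
    """Add precomputed pointwise log-likelihood values to an InferenceData object.

    Hypothetical helper, not part of ArviZ. `values` is assumed to be an
    array-like with shape (chain, draw, n_observations).
    """
    if len(values.shape) < 2:
        raise ValueError("values must have leading (chain, draw) dimensions")
    # Optionally name the observation dimension; otherwise a default name is used.
    dims = None if obs_dim is None else {name: [obs_dim]}
    # The InferenceData should have been created with log_likelihood=False,
    # so that no log_likelihood group exists yet when this one is added.
    idata.add_groups({"log_likelihood": {name: values}}, dims=dims)
    return idata


# Usage with the objects from the snippet in this issue (not executed here):
#   samples = mcmc.get_samples(group_by_chain=True)
#   idata = az.from_numpyro(mcmc, log_likelihood=False)
#   attach_log_likelihood(idata, np.asarray(samples["ll"]), "xs_ys", obs_dim="obs_id")
#   list(idata.log_likelihood.data_vars)
```

Converting the JAX array with `np.asarray` before passing it in is a precaution rather than a documented requirement.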
The docstring needs updating though to include the log_likelihood parameter, its type info and description. Do you want to create a PR for this?
Alright, thanks for pointing that out. I will work around it like you suggested with the add_groups after creating the InferenceData object for now. I'm not sure I have the time currently to work on PRs, but I can see if I get around to it some time soon.