`arviz.from_numpyro` seems to ignore `log_likelihood` argument

Open lumip opened this issue 2 years ago • 2 comments

Describe the bug

The arviz.from_numpyro method has an argument log_likelihood, which seems to let the user provide a dictionary(?) of likelihoods for posterior samples. However, it appears to be entirely ignored in favour of likelihoods computed from the provided model. (I may be using it incorrectly, as it appears to be undocumented; if so, please point that out to me.)

To Reproduce

Consider the following code snippet:

import numpyro.distributions as dists
import numpyro
import jax
import jax.numpy as jnp
import numpy as np
import arviz as az

def model(xs, ys):
    mu = numpyro.sample("mu", dists.Normal(0., 1.))
    with numpyro.plate("xs_obs", len(xs) if xs is not None else 1):
        xs_dist = dists.Normal(mu, 1.)
        xs = numpyro.sample("xs", xs_dist, obs=xs)
    with numpyro.plate("ys_obs", len(ys) if ys is not None else 1):
        ys_dist = dists.Normal(mu, 4.)
        ys = numpyro.sample("ys", ys_dist, obs=ys)

    ll = jnp.concatenate([xs_dist.log_prob(xs), ys_dist.log_prob(ys)])
    numpyro.deterministic("ll", ll)

xs = np.random.randn(1000) + 5.
ys = np.random.randn(1000) * 4. + 5.

nuts = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(nuts, num_chains=4, num_warmup=100, num_samples=2000)
mcmc.run(jax.random.PRNGKey(782385), xs, ys)

samples = mcmc.get_samples(group_by_chain=True)
ll = { 'xs_ys': samples['ll'] }

idata = az.from_numpyro(mcmc, log_likelihood=ll)
print(idata['log_likelihood'])
assert list(idata['log_likelihood'].data_vars) == ["xs_ys"]

This prints

<xarray.Dataset>
Dimensions:   (chain: 4, draw: 2000, xs_dim_0: 1000, ys_dim_0: 1000)
Coordinates:
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 ... 1993 1994 1995 1996 1997 1998 1999
  * xs_dim_0  (xs_dim_0) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
  * ys_dim_0  (ys_dim_0) int64 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
Data variables:
    xs        (chain, draw, xs_dim_0) float32 -2.074 -1.25 ... -1.283 -1.435
    ys        (chain, draw, ys_dim_0) float32 -4.216 -2.825 ... -5.277 -2.749
Attributes:
    created_at:                 2023-01-21T14:29:16.129290
    arviz_version:              0.14.0
    inference_library:          numpyro
    inference_library_version:  0.10.1

Expected behavior

idata.log_likelihood should have a single data variable xs_ys with the contents of samples['ll']; that is, arviz.from_numpyro should use the values provided via the log_likelihood argument instead of recomputing the log-likelihoods from the model.

Additional context

arviz==0.14.0
numpyro==0.10.1
jax==0.4.1
jaxlib==0.4.1

lumip avatar Jan 21 '23 14:01 lumip

The log_likelihood parameter can only be a boolean (True or False, or None if you want the value in rcParams to be used). If you want custom values in the log likelihood group that differ from what is defined by the model, you should either modify those after calling from_numpyro, or use log_likelihood=False and then InferenceData.add_groups to create the group later on (as done, for example, in https://python.arviz.org/en/stable/user_guide/numpyro_refitting_xr_lik.html). Hope it helps.
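A minimal sketch of the add_groups route, assuming the arviz 0.x InferenceData API. To keep it runnable without sampling, az.from_dict stands in for az.from_numpyro(mcmc, log_likelihood=False), and random arrays stand in for samples['ll']; the dimension name "obs" and the array sizes are illustrative choices, not anything from the thread.

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
n_chains, n_draws, n_obs = 4, 50, 20  # illustrative sizes

# Stand-in for: idata = az.from_numpyro(mcmc, log_likelihood=False)
# The result is an InferenceData object without a log_likelihood group.
idata = az.from_dict(posterior={"mu": rng.normal(size=(n_chains, n_draws))})

# Stand-in for: ll = np.asarray(samples["ll"]), shaped (chain, draw, observation)
ll = rng.normal(size=(n_chains, n_draws, n_obs))

# Create the group afterwards with a custom variable name.
# "obs" is a hypothetical name for the observation dimension.
idata.add_groups(
    {"log_likelihood": {"xs_ys": ll}},
    dims={"xs_ys": ["obs"]},
)

print(idata.log_likelihood)
```

With the real MCMC object, the pointwise values would come from mcmc.get_samples(group_by_chain=True), so that the leading dimensions are already chain and draw.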

The docstring needs updating though to include the log_likelihood parameter, its type info and description. Do you want to create a PR for this?

OriolAbril avatar Jan 21 '23 18:01 OriolAbril

Alright, thanks for pointing that out. I will work around it like you suggested with the add_groups after creating the InferenceData object for now. I'm not sure I have the time currently to work on PRs, but I can see if I get around to it some time soon.

lumip avatar Jan 23 '23 18:01 lumip