arviz
arviz copied to clipboard
select individual levels from the Dimensions to plot pystan
Short Description
I want to select individual levels from the Dimensions to plot because plotting all of the levels of a variable is slow and the resulting plot is uninterpretable.
Code Example or link
I am trying to reproduce the PyStan example here showing the use of multilevel modelling.
The code and extraction are below:
varying_intercept = """
data {
int<lower=0> J; // the number of counties
int<lower=0> N; // the number of observations
int<lower=1,upper=J> county[N]; // the county for each observation
vector[N] x; // predictor/regressor (floor/basement)
vector[N] y; // the output variable (log radon levels)
}
parameters {
vector[J] a; // the random intercept
real b; // the fixed coefficient (FIXED EFFECT)
real mu_a; // mean of the population of counties (CONSTANT FOR POPULATION)
real<lower=0,upper=100> sigma_a; // the standard deviation between the counties (CONSTANT FOR POPULATION)
real<lower=0,upper=100> sigma_y; // the standard deviation of the observations (CONSTANT FOR POPULATION)
}
transformed parameters {
vector[N] y_hat; // estimated log radon level for each datapoint
for (i in 1:N) // for each datapoint
y_hat[i] <- a[county[i]] + x[i] * b; // estimate the mean of the log radon as a simple linear regression
}
model {
sigma_a ~ uniform(0, 100); // variation between the counties
a ~ normal (mu_a, sigma_a); // the intercept varying (RANDOM EFFECT)
b ~ normal (0, 1); // the coefficient
sigma_y ~ uniform(0, 100); // the sampling variation of the log-radon
y ~ normal(y_hat, sigma_y); // model the log radon levels
}
"""
varying_intercept_data = {'N': len(log_radon),
'J': len(n_county),
'county': county+1, # Stan counts starting at 1
'x': floor_measure,
'y': log_radon}
varying_intercept_fit = pystan.stan(model_code=varying_intercept, data=varying_intercept_data, iter=1000, chains=2)
I then extract the data to ArviZ:
fit = varying_intercept_fit
data = az.from_pystan(posterior=fit,
posterior_predictive='y_hat',
observed_data=['y'],
coords={'county': n_county},
dims={'a': ['county']}) #, 'y': ['county'], 'log_lik': ['county'], 'y_hat': ['county'], 'theta_tilde': ['county']})
data
Out[]:
Inference data with groups:
> posterior
> sample_stats
> posterior_predictive
> observed_data
I want to make a plot of only a few of the counties (the model levels). The following takes an age to run because it plots the traces for ALL counties, but I want to select only some of them.
az.traceplot(data)
I found this help here:
az.plot_trace(data, var_names='a', coords={'county': range(0, 5)});
az.plot_forest(data.posterior.sel(county=range(0, 5)), var_names='a');
az.plot_parallel(data, var_names='a', coords={'county': range(0, 5)});
az.plot_posterior(data, var_names='a', coords={'county': range(0, 5)});
But I get an error:
---------------------------------------------------------------------------
InvalidIndexError Traceback (most recent call last)
<ipython-input-35-46d226304076> in <module>
1 # select only some of the levels to plot!!!
2 # https://discourse.pymc.io/t/best-way-to-plot-and-do-ppc-with-variable-that-has-too-many-levels/2276/3
----> 3 az.plot_trace(data, var_names='a', coords={'county': range(0, 5)});
4 # az.plot_forest(data.posterior.sel(county=range(0, 5)), var_names='a');
5 # az.plot_parallel(data, var_names='a', coords={'county': range(0, 5)});
~/miniconda3/envs/stan/lib/python3.7/site-packages/arviz/plots/traceplot.py in plot_trace(data, var_names, coords, divergences, figsize, textsize, lines, combined, kde_kwargs, hist_kwargs, trace_kwargs)
105 lines = ()
106
--> 107 plotters = list(xarray_var_iter(get_coords(data, coords), var_names=var_names, combined=True))
108
109 if figsize is None:
~/miniconda3/envs/stan/lib/python3.7/site-packages/arviz/plots/plot_utils.py in get_coords(data, coords)
322 """
323 try:
--> 324 return data.sel(**coords)
325
326 except ValueError:
~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/dataset.py in sel(self, indexers, method, tolerance, drop, **indexers_kwargs)
1608 indexers = either_dict_or_kwargs(indexers, indexers_kwargs, 'sel')
1609 pos_indexers, new_indexes = remap_label_indexers(
-> 1610 self, indexers=indexers, method=method, tolerance=tolerance)
1611 result = self.isel(indexers=pos_indexers, drop=drop)
1612 return result._replace_indexes(new_indexes)
~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/coordinates.py in remap_label_indexers(obj, indexers, method, tolerance, **indexers_kwargs)
353
354 pos_indexers, new_indexes = indexing.remap_label_indexers(
--> 355 obj, v_indexers, method=method, tolerance=tolerance
356 )
357 # attach indexer's coordinate to pos_indexers
~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/indexing.py in remap_label_indexers(data_obj, indexers, method, tolerance)
256 else:
257 idxr, new_idx = convert_label_indexer(index, label,
--> 258 dim, method, tolerance)
259 pos_indexers[dim] = idxr
260 if new_idx is not None:
~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/indexing.py in convert_label_indexer(index, label, index_name, method, tolerance)
192 raise ValueError('Vectorized selection is not available along '
193 'MultiIndex variable: ' + index_name)
--> 194 indexer = get_indexer_nd(index, label, method, tolerance)
195 if np.any(indexer < 0):
196 raise KeyError('not all values found in index %r'
~/miniconda3/envs/stan/lib/python3.7/site-packages/xarray/core/indexing.py in get_indexer_nd(index, labels, method, tolerance)
120
121 flat_labels = np.ravel(labels)
--> 122 flat_indexer = index.get_indexer(flat_labels, **kwargs)
123 indexer = flat_indexer.reshape(labels.shape)
124 return indexer
~/miniconda3/envs/stan/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_indexer(self, target, method, limit, tolerance)
2737
2738 if not self.is_unique:
-> 2739 raise InvalidIndexError('Reindexing only valid with uniquely'
2740 ' valued Index objects')
2741
InvalidIndexError: Reindexing only valid with uniquely valued Index objects
Also include the ArviZ version and version of any other relevant packages.
Arviz Version: 0.3.2
numpy Version: 1.15.0
pandas Version: 0.24.1
Relevant documentation or public examples
https://mc-stan.org/users/documentation/case-studies/radon.html
We need a more flexible selection for the variables, but I am not sure what would be our best option for the API.
@OriolAbril did we have any update on this issue?
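This issue is actually due to the indexing properties of xarray: label-based selection with a list of values fails when the coordinate index contains repeated values. A simple example, unrelated to ArviZ, is below: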
import xarray as xr
import numpy as np
data = xr.DataArray(
data=np.random.random(size=(4,100,8)),
dims=("chain", "draw", "dim1"),
coords={"chain": range(4), "draw": range(100), "dim1": np.random.choice([0,1,2], size=8)}
)
print(data)
# output
# <xarray.DataArray (chain: 4, draw: 100, dim1: 8)>
# array([[[0.828962, 0.514844, ..., 0.180102, 0.365011],
# ...
# [0.548044, 0.621308, ..., 0.373455, 0.586788]]])
# Coordinates:
# * chain (chain) int64 0 1 2 3
# * draw (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
# * dim1 (dim1) int64 1 2 1 1 2 2 1 0
print(data.sel(dim1=1))
# output
# <xarray.DataArray (chain: 4, draw: 100, dim1: 4)>
# array([[[0.828962, 0.697391, 0.384503, 0.180102],
# ...
# [0.548044, 0.793319, 0.735403, 0.373455]]])
# Coordinates:
# * chain (chain) int64 0 1 2 3
# * draw (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
# * dim1 (dim1) int64 1 1 1 1
# but
data.sel(dim1=[1,2])
# output
# InvalidIndexError: Reindexing only valid with uniquely valued Index objects
To actually select a subset of a DataArray or Dataset based on a coordinate with repeated index values, `where` must be used.
data.where(data.dim1.isin((1,2)), drop=True)
# drop is False by default, in which case values not fulfilling the condition are converted to NaN instead of being removed, which is not our goal
# output
# <xarray.DataArray (chain: 4, draw: 100, dim1: 7)>
# array([[[0.828962, 0.514844, ..., 0.042839, 0.180102],
# ...
# [0.548044, 0.621308, ..., 0.110815, 0.373455]]])
# Coordinates:
# * chain (chain) int64 0 1 2 3
# * draw (draw) int64 0 1 2 3 4 5 6 7 8 9 ... 90 91 92 93 94 95 96 97 98 99
# * dim1 (dim1) int64 1 2 1 1 2 2 1
We could discuss how to implement this in ArviZ. For now, it must be done by the user before calling ArviZ functions. I guess that in your case it would be something like:
az.plot_forest(
data.posterior.where(data.posterior.county.isin(range(0, 5)), drop=True),
var_names='a'
);
@OriolAbril do we have `.where` described somewhere in the docs? It is a really powerful function.