arviz
Better API for obtaining posterior point estimates & more
As part of a notebook for pymc (https://github.com/pymc-devs/pymc-examples/pull/241) to support the addition of the Generalized Extreme Value distribution (https://github.com/pymc-devs/pymc/pull/5085), I ran into a few complications processing results with the current API.
I'm raising this issue to see if there is appetite for a PR along the lines proposed here.
- This code snippet is used to compare a posterior result with the maximum likelihood estimate from the reference book:
_, vals = az.sel_utils.xarray_to_ndarray(trace["posterior"], var_names=["μ", "σ", "ξ"])
mle = [az.plots.plot_utils.calculate_point_estimate("mode", val) for val in vals]
As can be seen, this uses/abuses a few back-end ArviZ functions. It would seem better to have a cleaner API to access the point estimates that the HDI plots already offer through their point_estimate argument, such as mean, mode, median. Something like: az.get_point_estimate(point_estimate='mode', var_names=["μ", "σ", "ξ"]).
- Getting the variance-covariance matrix of the estimates requires a pandas interface:
trace["posterior"].drop_vars("z_p").to_dataframe().cov().round(6)
Again, this is a bit non-Bayesian, but is useful for comparison with results from other sources. So something like: az.get_var_covar(var_names=["μ", "σ", "ξ"]).
- Again, looking at that InferenceData accessor to the xarray drop_vars, it would be neat if there was a comparable get_vars which returned the results for the selected variables. This functionality is already built in, of course, as it is used through the arguments to many of the plot functions. But something directly like: trace["posterior"].get_vars(["μ", "σ", "ξ"]) would be helpful.
- More minor: It seems that to examine the prior predictive checks, we should now use the plot_posterior function. I suspect a plot_prior wrapper would be more logical and make for more readable code.
az.plot_posterior(
prior_pc, group="prior", var_names=["μ", "σ", "ξ"], hdi_prob="hide", point_estimate=None
);
Intro
such as mean, mode, median
mean and median are already available via xarray, so I think we should not reimplement those. Doing idata.posterior[["subset", "of vars", "if desired"]].median() already works, and you can use dim=("chain", "draw") to specify which dimensions to reduce by name. Same for mean. Reference: https://xarray.pydata.org/en/stable/generated/xarray.Dataset.median.html, https://xarray.pydata.org/en/stable/generated/xarray.Dataset.mean.html.
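For illustration, here is a minimal sketch of that pattern. The dataset below is a synthetic stand-in for idata.posterior (the variable names come from the issue, the numbers are made up), so treat it as an assumption-laden example rather than output from the actual model.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior: 4 chains x 500 draws per variable.
# The values are invented for illustration only.
rng = np.random.default_rng(42)
shape = (4, 500)
posterior = xr.Dataset(
    {
        "μ": (("chain", "draw"), rng.normal(3.9, 0.03, size=shape)),
        "σ": (("chain", "draw"), rng.normal(0.2, 0.02, size=shape)),
        "ξ": (("chain", "draw"), rng.normal(-0.05, 0.1, size=shape)),
        "z_p": (("chain", "draw"), rng.normal(4.7, 0.2, size=shape)),
    }
)

# Indexing with a list of names keeps only those variables (still a Dataset).
subset = posterior[["μ", "σ", "ξ"]]

# Reduce over the sampling dimensions by name.
medians = subset.median(dim=("chain", "draw"))
means = subset.mean(dim=("chain", "draw"))

print({name: float(medians[name]) for name in medians.data_vars})
```

Indexing with a list of names also covers much of what a get_vars method would do, since it returns a Dataset restricted to the selected variables.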
It might be interesting to try and make the mode we use in some plots available with a similar API. It should not be too difficult given it's already implemented, if using apply_ufunc or wrap_xarray_ufunc carefully.
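As a rough sketch of what that could look like with xarray.apply_ufunc: the helper hist_mode below is hypothetical and uses a crude histogram-based estimate, which is an assumption made for this example and not the mode computation ArviZ uses in its plots. The posterior dataset is again synthetic.

```python
import numpy as np
import xarray as xr


def hist_mode(values, bins=30):
    """Hypothetical helper: midpoint of the fullest histogram bin."""
    # np.histogram flattens its input, so a (chain, draw) array is fine.
    counts, edges = np.histogram(values, bins=bins)
    i = np.argmax(counts)
    return 0.5 * (edges[i] + edges[i + 1])


# Synthetic stand-in for idata.posterior (values invented for illustration).
rng = np.random.default_rng(0)
posterior = xr.Dataset(
    {
        "μ": (("chain", "draw"), rng.normal(3.0, 0.1, size=(4, 1000))),
        "σ": (("chain", "draw"), rng.normal(1.0, 0.1, size=(4, 1000))),
    }
)

# Treat chain and draw as core dimensions that the function consumes;
# vectorize=True loops over any remaining dimensions of each variable.
modes = xr.apply_ufunc(
    hist_mode,
    posterior,
    input_core_dims=[["chain", "draw"]],
    vectorize=True,
)

print({name: float(modes[name]) for name in modes.data_vars})
```

Because chain and draw are declared as core dimensions, the result has one value per variable (or per remaining coordinate for multidimensional variables), mirroring what .mean() and .median() return.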
Getting the variance-covariance matrix of the estimates requires a pandas interface:
I have never used that, but it looks general enough to live in xarray directly, and it seems to be somewhat available already: https://xarray.pydata.org/en/stable/generated/xarray.cov.html. If this is not good enough, we should try to push those improvements directly to xarray.
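A possible way to try this is sketched below. xarray.cov works on a pair of DataArrays, so the full variance-covariance matrix is assembled pairwise here; that loop, the synthetic posterior, and the output dimension names var_a and var_b are assumptions of this example, not an existing ArviZ or xarray helper.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for idata.posterior (values invented for illustration).
rng = np.random.default_rng(1)
shape = (4, 500)
posterior = xr.Dataset(
    {
        "μ": (("chain", "draw"), rng.normal(3.9, 0.03, size=shape)),
        "σ": (("chain", "draw"), rng.normal(0.2, 0.02, size=shape)),
        "ξ": (("chain", "draw"), rng.normal(-0.05, 0.1, size=shape)),
    }
)

names = ["μ", "σ", "ξ"]

# xr.cov handles one pair of DataArrays at a time, so build the matrix pairwise,
# reducing over both sampling dimensions.
cov_values = np.array(
    [
        [float(xr.cov(posterior[a], posterior[b], dim=("chain", "draw"))) for b in names]
        for a in names
    ]
)

# Wrap the result with labels so rows and columns can be selected by name.
cov = xr.DataArray(
    cov_values,
    dims=("var_a", "var_b"),
    coords={"var_a": names, "var_b": names},
)

print(cov.round(6))
```

The diagonal holds the posterior variances and the off-diagonal entries the covariances, which should match what the pandas to_dataframe().cov() route gives for the same variables.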
Again, looking at that InferenceData accessor to the xarray drop_vars, it would be neat if there was a comparable get_vars which returned the results for the selected variables - this functionality is already built-in of course, as is used through the arguments to many of the plot functions. But something directly like: trace["posterior"].get_vars(["μ", "σ", "ξ"]) would be helpful.
Is this the same as (or would it be solved satisfactorily by) https://github.com/arviz-devs/arviz/pull/1725 (which fixes https://github.com/arviz-devs/arviz/issues/1469)?
It seems that to examine the prior predictive checks, we should now use the plot_posterior function. I suspect a plot_prior wrapper would be more logical and more readable code.
The name might not have been the best choice, but plot_posterior, like plot_density and several other functions, can be used on any group, not necessarily the posterior or the prior groups. Adding plot_prior would probably mean open season for plot_posterior_predictive, plot_sample_stats... which I believe would end up being even more confusing.
Also note that plot_ppc
can be used for either posterior or prior predictive checks in the comparing generated distributions to observed data, plot_posterior
plots distributions, and so it can be used to plot the prior distributions and also consequently for prior predictive checks, but is not necessarily tied to prior predictive checks not includes the observed data in the plot.
Practical remarks
- We should consider making the mode computation available. The main downside I can see is that it will probably sit towards the low-priority end and might take time to land, even if it is not many lines of code. I don't think there would be opposition to the change, but I might be wrong.
- Can you try the xarray.cov thing and let us know how it goes?
- I have not had much time lately and have not gone back to finish the extract_dataset PR. If you feel up for it, feel free to take the work there and push it past the finish line.
- Maybe we should consider renaming plot_posterior?