[WIP] Add power-scaling sensitivity analysis diagnostic
This PR adds functionality for power-scaling sensitivity analysis. Details of the approach are described in Kallioinen et al. (2022), https://arxiv.org/abs/2107.14054, and the corresponding R package (with additional functionality) is available at https://github.com/n-kall/priorsense
This PR adds the base functionality required to calculate the sensitivity diagnostic, which I have labelled 'psens' (feel free to change the naming). The diagnostic can be used to check whether the posterior is sensitive to changes in the prior or likelihood without needing to refit the model.
Calculating the diagnostic requires evaluations of the log_prior and log_likelihood at the posterior draws. Currently, the log_likelihood is available, but the log_prior is not saved by PyMC. So, for psens to function properly, PyMC would need a way to save the log_prior evaluations when fitting a model.
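To illustrate why those pointwise evaluations are all that is needed: power-scaling a component p(.)^alpha turns the perturbation into an importance sampling problem over the existing posterior draws. A rough sketch following Kallioinen et al. (2022) — this is an illustration, not the PR implementation, and the function name is made up:

import numpy as np

def power_scale_weights(log_component, alpha):
    # log_component: log_prior (or log_likelihood) evaluated at each
    # posterior draw; alpha: power-scaling factor (alpha = 1 is no change)
    log_w = (alpha - 1) * log_component
    log_w -= log_w.max()  # stabilize before exponentiating
    w = np.exp(log_w)
    return w / w.sum()  # self-normalized importance weights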
:books: Documentation preview :books:: https://arviz--2093.org.readthedocs.build/en/2093/
@OriolAbril we should probably add prior_loglikelihood to the InferenceData schema?
@OriolAbril Thanks for all the comments and suggestions! I'll clean it up shortly
I realise I left out an example of the imagined usage:
import arviz as az
import numpy as np
import pymc as pm

draws = 2000
chains = 4
data = {"y": np.array([1, 2, 3, 4, 5])}

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=2.5)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=data["y"])
    post = pm.sample(draws, chains=chains)

az.psens(post, component="likelihood")
I left two somewhat conceptual comments in the discussion above, but it already looks good. I'll leave some more specific comments below to start getting the PR ready to merge.
I updated the CJS calculation to no longer use nansum; instead it uses the where argument: np.log2(x, out=np.zeros_like(x), where=(x != 0)). Do you think this is sufficient?
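For reference, a minimal sketch (with made-up data) of what that pattern does: entries where x == 0 are skipped by where= and keep the zeros provided through out, so the 0 * log(0) terms in the CJS sum contribute 0 instead of NaN:

import numpy as np

x = np.array([0.5, 0.25, 0.0, 0.25])
logx = np.log2(x, out=np.zeros_like(x), where=(x != 0))
terms = x * logx  # the x == 0 entry contributes exactly 0, no NaN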
I've made some updates regarding the extraction of the log_likelihood and log_prior. Below is a working example using CmdStanPy.
I think it might make sense to include log_prior as a group in InferenceData, just as log_likelihood is stored. The current code assumes there is a group called 'log_prior', which is manually added in the example.
Stan model (example.stan):
data {
  int N;
  vector[N] x;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  target += normal_lpdf(mu | 0, 1);
  target += normal_lpdf(sigma | 0, 2.5);
  target += normal_lpdf(x | mu, sigma);
}
generated quantities {
  vector[2] lprior;
  vector[N] log_lik;
  lprior[1] = normal_lpdf(mu | 0, 1);
  lprior[2] = normal_lpdf(sigma | 0, 2.5);
  for (n in 1:N) {
    log_lik[n] = normal_lpdf(x[n] | mu, sigma);
  }
}
import cmdstanpy
import arviz as az

model = cmdstanpy.CmdStanModel(stan_file="example.stan")
dat = {"N": 5, "x": [1, 0.5, 3, 5, 10]}
fit = model.sample(data=dat)
idata = az.from_cmdstanpy(fit)

# move the lprior draws into their own group
lprior = az.extract(idata, var_names="lprior", group="posterior", combined=False)
idata.add_groups(log_prior={"lprior": lprior})
# prior sensitivity (for all priors)
az.stats.psens(idata, component="prior", var_names=["mu", "sigma"])
# likelihood sensitivity (for joint likelihood)
az.stats.psens(idata, component="likelihood", var_names=["mu", "sigma"])
# prior sensitivity for prior on mu
az.stats.psens(idata, component="prior", var_names=["mu", "sigma"], selection=[0])
# likelihood sensitivity for single observation
az.stats.psens(idata, component="likelihood", var_names=["mu", "sigma"], selection=[1])
This PR is definitely a very good reason to include log_prior as a group in InferenceData. It seems you have some pylint issues to solve.
Isn't it enough to have a single total log prior value per sample? If so, I would put it in sample_stats, like lp.
It can be useful to have lprior as an array with the different contributions to the joint log prior, so that it is possible to check the sensitivity to changing a subset of the priors. Currently this works with the 'selection' argument; lprior is then an N_priors x N_draws array. I'm not so familiar with the InferenceData schema, so wherever this would make sense to be stored, just let me know and I can adjust the functions accordingly.
Codecov Report
Attention: 3 lines in your changes are missing coverage. Please review.
Comparison is base (22c0dcb) 86.70% compared to head (432956a) 86.74%.
| Files | Patch % | Lines |
|---|---|---|
| arviz/stats/stats_utils.py | 80.00% | 2 Missing :warning: |
| arviz/stats/stats.py | 98.18% | 1 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #2093 +/- ##
==========================================
+ Coverage 86.70% 86.74% +0.04%
==========================================
Files 122 122
Lines 12640 12703 +63
==========================================
+ Hits 10959 11019 +60
- Misses 1681 1684 +3
I think it would be best to store it in an independent group with the same variables as in the posterior and prior groups, but where instead of samples we store pointwise log prior values. This is very similar to what we already do for the pointwise log likelihood, which is stored in its own group using the same variable names as in the posterior predictive and observed data groups.
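As a hypothetical sketch using the mu/sigma example above (the dimension name lprior_dim_0 is the auto-generated default and an assumption here), that could look like:

idata.add_groups(
    log_prior={
        # one variable per prior, mirroring the posterior group layout
        "mu": lprior.isel(lprior_dim_0=0),
        "sigma": lprior.isel(lprior_dim_0=1),
    }
)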
To get the per-sample total log prior, to_stacked_array can be used. For example:
After loading the rugby dataset, for example, the posterior looks like this:
idata = az.load_arviz_data("rugby")
idata.posterior
an xarray Dataset (a dict-like structure) with multiple variables, sharing some dimensions but each with its own independent shape:
<xarray.Dataset>
Dimensions: (chain: 4, draw: 500, team: 6)
Coordinates:
* chain (chain) int64 0 1 2 3
* draw (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
* team (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
Data variables:
home (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
intercept (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
atts_star (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
defs_star (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
sd_att (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
sd_def (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
atts (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
defs (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986
To get them as a 2d array we can do:
az.extract(idata).to_stacked_array("latent_var", sample_dims=("sample",))
<xarray.DataArray 'home' (sample: 2000, latent_var: 28)>
array([[ 0.16416114, 2.89297934, 0.16729464, ..., 0.2082224 ,
0.55365865, -0.22098263],
...
[ 0.22646978, 2.89236378, 0.0439508 , ..., 0.13586028,
0.57993105, -0.19855837]])
Coordinates:
* sample (sample) object MultiIndex
* chain (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
* draw (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
* latent_var (latent_var) object MultiIndex
* variable (latent_var) object 'home' 'intercept' ... 'defs' 'defs'
* team (latent_var) object nan nan 'Wales' ... 'Italy' 'England'
And adding .sum("latent_var") would give the total per-sample log prior if these were pointwise log prior contributions.
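For instance (total_lp is just an illustrative name):

total_lp = az.extract(idata).to_stacked_array(
    "latent_var", sample_dims=("sample",)
).sum("latent_var")  # one value per posterior sample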
IMO, choosing which variables to include (or which subsets of them) should happen on the first object, which is what users really know: these are the variables they define in their Stan or PyMC models.
I am not sure about what API to use though. For model comparison, we chose to force users to select a single var_name (unlike all other functions, which also accept lists) because ArviZ a priori doesn't know how to combine the values into a pointwise log likelihood when there are multiple variables: we could consider loo with multivariate observations, loo with univariate observations, even logo... Here, however, we only want the per-sample values, so summing over everything but the sampling dimensions sounds like a good default.
There is some environment-dependent effect that triggers a recursion error. Will continue looking into it tomorrow.
It seems xarray objects can't be passed as kwargs to apply_ufunc; they should either be converted to numpy arrays or used as positional arguments.
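A minimal illustration of the positional form (variable names made up):

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.randn(4, 500), dims=("chain", "draw"))
weights = xr.DataArray(np.random.randn(4, 500), dims=("chain", "draw"))

# positional inputs are aligned and broadcast by apply_ufunc; passing
# weights through kwargs would hand it to the function untouched
out = xr.apply_ufunc(lambda x, w: x * w, da, weights)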
Also, here is the link to the docstring preview: https://arviz--2093.org.readthedocs.build/en/2093/api/generated/arviz.psens.html. The tests aren't super exhaustive, but I think they are good for now.
@OriolAbril Thanks for all the help with this PR! Greatly appreciated!