
[WIP] Add power-scaling sensitivity analysis diagnostic

n-kall opened this pull request 3 years ago • 5 comments

This PR adds functionality for power-scaling sensitivity analysis. Details of the approach are described in Kallioinen et al. (2022), https://arxiv.org/abs/2107.14054, and the corresponding R package (with additional functionality) is available at https://github.com/n-kall/priorsense

This PR adds the base functionality required to calculate the sensitivity diagnostic, which I have labelled 'psens' (feel free to change the naming). This diagnostic can be used to check whether the posterior is sensitive to changes in the prior or likelihood, without needing to refit the model.

In order to calculate the diagnostic, evaluations of the log_prior and log_likelihood at the posterior draws are required. Currently, the log_likelihood is available, but the log_prior is not saved by PyMC. So, for psens to function properly, PyMC would need a way to save the log_prior evaluations when fitting a model.
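To make concrete what psens needs, here is a minimal sketch of those quantities, with hypothetical draws and scipy used purely for illustration (the priors match the usage example further down):

import numpy as np
from scipy import stats

# hypothetical posterior draws for mu ~ Normal(0, 1) and sigma ~ HalfNormal(2.5),
# with shape (chain, draw)
rng = np.random.default_rng(0)
mu_draws = rng.normal(size=(4, 1000))
sigma_draws = np.abs(rng.normal(scale=2.5, size=(4, 1000)))

# pointwise log prior evaluated at each posterior draw; power-scaling
# reweights the draws using these values (and log_likelihood for the
# likelihood component)
log_prior_mu = stats.norm(0, 1).logpdf(mu_draws)
log_prior_sigma = stats.halfnorm(scale=2.5).logpdf(sigma_draws)
log_prior = log_prior_mu + log_prior_sigma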


:books: Documentation preview :books:: https://arviz--2093.org.readthedocs.build/en/2093/

n-kall avatar Aug 11 '22 07:08 n-kall

@OriolAbril we should probably add prior_loglikelihood to the InferenceData schema?

ahartikainen avatar Aug 11 '22 08:08 ahartikainen

@OriolAbril Thanks for all the comments and suggestions! I'll clean it up shortly

n-kall avatar Aug 18 '22 09:08 n-kall

I realise I left out an example of the intended usage:

import arviz as az
import numpy as np
import pymc as pm

draws = 2000
chains = 4

data = {"y" = np.array([1,2,3,4,5])}

with pm.Model() as model:
    mu = pm.Normal("mu", mu=0, sigma=1)
    sigma = pm.HalfNormal("sigma", sigma=2.5)
    pm.Normal("obs", mu=mu, sigma=sigma, observed=data["y"])
    post = pm.sample(draws, chains=chains)

az.psens(post, component="likelihood")

n-kall avatar Aug 18 '22 11:08 n-kall

I left two comments in the discussion above that are still a bit conceptual, but it already looks good. I'll leave some more specific comments below to start getting the PR merge-ready.

OriolAbril avatar Aug 19 '22 17:08 OriolAbril

I updated the CJS calculation to no longer use nansum; instead it uses np.log2(x, out=np.zeros_like(x), where=(x != 0)). Do you think this is sufficient?
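For reference, a minimal sketch of that pattern on toy values (not the actual CJS code):

import numpy as np

x = np.array([0.0, 0.25, 0.5, 0.25])
# log2 is evaluated only where x != 0; elsewhere the zeros from `out` are kept,
# so the x * log2(x) terms at x == 0 contribute 0 rather than nan
logx = np.log2(x, out=np.zeros_like(x), where=(x != 0))
terms = x * logx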

n-kall avatar Sep 19 '22 08:09 n-kall

I've made some updates regarding the extraction of the log_likelihood and log_prior. Below is a working example using CmdStanPy.

I think it might make sense to include log_prior as a group in InferenceData, just as log_likelihood is stored. The current code assumes there is a group called 'log_prior', which is added manually in the example.

Stan model (example.stan):

data {
  int N;
  vector[N] x;
}

parameters {
  real mu;
  real<lower=0> sigma;
}

model {
  target += normal_lpdf(mu | 0, 1);
  target += normal_lpdf(sigma | 0, 2.5);
  target += normal_lpdf(x | mu, sigma);
}

generated quantities {
  vector[2] lprior;
  vector[N] log_lik;

  lprior[1] = normal_lpdf(mu | 0, 1);
  lprior[2] = normal_lpdf(sigma | 0, 2.5);

  for (n in 1:N) {
    log_lik[n] = normal_lpdf(x[n] | mu, sigma);
  }
}

Python code:

import cmdstanpy as stan
import arviz as az
import numpy as np

model = stan.CmdStanModel(stan_file="example.stan")

dat = {"N": 5, "x": [1, 0.5, 3, 5, 10]}
fit = model.sample(data=dat)
# map the log_lik generated quantity into the log_likelihood group
idata = az.from_cmdstanpy(fit, log_likelihood="log_lik")
lprior = az.extract(idata, var_names="lprior", group="posterior", combined=False)

idata.add_groups(
    log_prior={"lprior": lprior}
)

# prior sensitivity (for all priors)
az.stats.psens(idata, component="prior", var_names=["mu", "sigma"])

# likelihood sensitivity (for joint likelihood)
az.stats.psens(idata, component="likelihood", var_names=["mu", "sigma"])

# prior sensitivity for prior on mu
az.stats.psens(idata, component="prior", var_names=["mu", "sigma"], selection=[0])

# likelihood sensitivity for single observation
az.stats.psens(idata, component="likelihood", var_names=["mu", "sigma"], selection=[1])

n-kall avatar Jan 11 '23 10:01 n-kall

This PR is definitely a very good reason to include log_prior as a group in InferenceData. It seems you have some pylint issues to solve.

aloctavodia avatar Jan 11 '23 13:01 aloctavodia

Isn't it enough to have a single total log prior value per sample? If so, I would put it in sample_stats, like lp.

OriolAbril avatar Jan 12 '23 19:01 OriolAbril

It can be useful to have lprior as an array with the different contributions to the joint log prior, so that it is possible to check the sensitivity to changing a subset of the priors. Currently this works with the 'selection' argument, with lprior then being an N_priors x N_draws array. I'm not so familiar with the InferenceData schema, so wherever it would make sense to store this, just let me know and I can adjust the functions accordingly.

n-kall avatar Jan 13 '23 13:01 n-kall

Codecov Report

Attention: 3 lines in your changes are missing coverage. Please review.

Comparison is base (22c0dcb) 86.70% compared to head (432956a) 86.74%.

Files                        Patch %   Lines
arviz/stats/stats_utils.py   80.00%    2 Missing :warning:
arviz/stats/stats.py         98.18%    1 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2093      +/-   ##
==========================================
+ Coverage   86.70%   86.74%   +0.04%     
==========================================
  Files         122      122              
  Lines       12640    12703      +63     
==========================================
+ Hits        10959    11019      +60     
- Misses       1681     1684       +3     


codecov[bot] avatar Jan 13 '23 13:01 codecov[bot]

I think it would be best to store it in an independent group with the same variables as in the posterior and prior groups, but where instead of samples we store pointwise log prior values. This is very similar to what we already do for the pointwise log likelihood, where we store the pointwise log likelihood values in their own group using the same variable names as in the posterior predictive and observed data groups.

To get the per-sample total log prior, to_stacked_array can be used. For example:

After loading the rugby dataset, the posterior looks like this:

idata = az.load_arviz_data("rugby")
idata.posterior

an xarray Dataset (a dict-like structure) with multiple variables, sharing some dimensions but each with its own independent shape:

<xarray.Dataset>
Dimensions:    (chain: 4, draw: 500, team: 6)
Coordinates:
  * chain      (chain) int64 0 1 2 3
  * draw       (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499
  * team       (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
Data variables:
    home       (chain, draw) float64 0.1642 0.1162 0.09299 ... 0.148 0.2265
    intercept  (chain, draw) float64 2.893 2.941 2.939 ... 2.951 2.903 2.892
    atts_star  (chain, draw, team) float64 0.1673 0.04184 ... -0.4652 0.02878
    defs_star  (chain, draw, team) float64 -0.03638 -0.04109 ... 0.7136 -0.0649
    sd_att     (chain, draw) float64 0.4854 0.1438 0.2139 ... 0.2883 0.4591
    sd_def     (chain, draw) float64 0.2747 1.033 0.6363 ... 0.5574 0.2849
    atts       (chain, draw, team) float64 0.1063 -0.01913 ... -0.2911 0.2029
    defs       (chain, draw, team) float64 -0.06765 -0.07235 ... 0.5799 -0.1986

To get them as a 2d array we can do:

az.extract(idata).to_stacked_array("latent_var", sample_dims=("sample",))
<xarray.DataArray 'home' (sample: 2000, latent_var: 28)>
array([[ 0.16416114,  2.89297934,  0.16729464, ...,  0.2082224 ,
         0.55365865, -0.22098263],
...
       [ 0.22646978,  2.89236378,  0.0439508 , ...,  0.13586028,
         0.57993105, -0.19855837]])
Coordinates:
  * sample      (sample) object MultiIndex
  * chain       (sample) int64 0 0 0 0 0 0 0 0 0 0 0 0 ... 3 3 3 3 3 3 3 3 3 3 3
  * draw        (sample) int64 0 1 2 3 4 5 6 7 ... 493 494 495 496 497 498 499
  * latent_var  (latent_var) object MultiIndex
  * variable    (latent_var) object 'home' 'intercept' ... 'defs' 'defs'
  * team        (latent_var) object nan nan 'Wales' ... 'Italy' 'England'

And adding .sum("latent_var") would get the total per sample log prior if these were pointwise log prior contributions.
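Putting the two steps together (continuing the rugby example above, with the posterior standing in for a pointwise log prior group):

total_lp = (
    az.extract(idata)
    .to_stacked_array("latent_var", sample_dims=("sample",))
    .sum("latent_var")
)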

Imo, choosing which variables (or subsets of them) to include should happen in the first object, since that is what users really know: these are the variables they define in their Stan or PyMC models.

I am not sure about what API to use though. For model comparison, we chose to force users to select a single var_name (unlike all other functions, which accept lists too) because ArviZ doesn't know a priori how to combine the values into a pointwise log likelihood when there are multiple variables: we could consider loo with multivariate observations, loo with univariate observations, even logo... Here, however, we only want the per-sample values, so summing over everything but the sampling dimensions sounds like a good default.
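A minimal sketch of that default, assuming a log_prior group whose variables mirror the posterior:

# reduce each variable to per-sample values by summing any extra dims,
# then sum across variables to get the total per-sample log prior
per_variable = [
    da.sum([d for d in da.dims if d not in ("chain", "draw")])
    for da in idata.log_prior.data_vars.values()
]
total_log_prior = sum(per_variable)  # DataArray with dims (chain, draw)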

OriolAbril avatar Jan 21 '23 00:01 OriolAbril

[screenshot of the recursion error]

There is some env-dependent effect that triggers a recursion error. Will continue looking into it tomorrow.

OriolAbril avatar Nov 30 '23 18:11 OriolAbril

It seems xarray objects can't be passed as kwargs to apply_ufunc; they should either be converted to numpy arrays or used as positional arguments.
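A minimal sketch of the distinction, with toy arrays and hypothetical names:

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(4, 100), dims=("chain", "draw"))
weights = xr.DataArray(np.random.rand(4, 100), dims=("chain", "draw"))

# positional inputs are aligned, broadcast, and unwrapped before the function is called
ok = xr.apply_ufunc(lambda a, b: a * b, da, weights)

# kwargs are passed through to the function untouched, so an xarray object
# arrives still wrapped; convert to numpy if it has to be a keyword argument
also_ok = xr.apply_ufunc(lambda a, b=None: a * b, da, kwargs={"b": weights.values})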

Also, here is the link to the docstring preview: https://arviz--2093.org.readthedocs.build/en/2093/api/generated/arviz.psens.html. The tests aren't super exhaustive, but I think they are good for now.

OriolAbril avatar Nov 30 '23 19:11 OriolAbril

@OriolAbril Thanks for all the help with this PR! Greatly appreciated!

n-kall avatar Dec 01 '23 13:12 n-kall