
Add Pareto Smoothed Importance Sampling (PSIS) diagnostic method

Open asmuzsoy opened this issue 1 year ago • 4 comments

The Pareto Smoothed Importance Sampling $\hat{k}$ diagnostic (as described by Yao et al. (2018)) can be used to check whether a surrogate posterior obtained with variational inference is a good approximation of the true posterior. It would be great if NumPyro had a built-in method to calculate $\hat{k}$ (as Pyro does).

I have created a preliminary implementation for my work (seen here) that calculates the log importance ratios and passes them to the arviz.stats.stats._psislw() method to calculate $\hat{k}$. It could easily be generalized for use with any model and guide.

Any thoughts would be much appreciated!

asmuzsoy, May 21 '24 23:05

That's very cool! I didn't know there was a nice diagnostic for how "good" variational approximations are. Just wondering: could this be a contribution to arviz itself, so that the method can also be used with other inference frameworks, e.g., Stan?

tillahoffmann, Aug 12 '24 17:08

There is a method in arviz (arviz.stats.stats._psislw()) that calculates $\hat{k}$ given the log importance ratios. But to calculate the log importance ratios, you need the numpyro model, the guide, the fitted guide parameters, and the samples you're trying to evaluate; hence my suggestion for a new method that does all of this in one go.
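For concreteness, those ingredients combine into one log importance ratio per draw from the fitted guide. Writing $q_\phi$ for the guide with fitted parameters $\phi$, $\theta$ for the latent variables, and $y$ for the data (notation assumed here):

```math
\log r_s = \log p(y, \theta_s) - \log q_\phi(\theta_s), \qquad \theta_s \sim q_\phi(\theta), \quad s = 1, \dots, S.
```

PSIS fits a generalized Pareto distribution to the largest of these ratios, and $\hat{k}$ is its estimated shape parameter. Roughly following Yao et al. (2018): $\hat{k} < 0.5$ suggests the approximation is good, $0.5 \le \hat{k} < 0.7$ that it is usable but imperfect, and $\hat{k} \ge 0.7$ that it is unreliable.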

asmuzsoy, Aug 12 '24 19:08

Do you know if it's possible to evaluate those other quantities from the arviz InferenceData object?

tillahoffmann, Aug 12 '24 20:08

I don't think so, unless arviz.InferenceData somehow incorporates a numpyro model and guide?

asmuzsoy, Aug 16 '24 18:08