numpyro
Add Pareto Smoothed Importance Sampling (PSIS) diagnostic method
The Pareto Smoothed Importance Sampling $\hat{k}$ diagnostic (as described by Yao et al. (2018)) can be used to check whether a surrogate posterior fitted with variational inference is a good approximation of the true posterior. It would be great if NumPyro had a built-in method to calculate $\hat{k}$ (as Pyro does with `pyro.infer.importance.psis_diagnostic`).
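For reference, here is a short summary of the quantities involved, following Yao et al. (2018) and Vehtari et al.'s PSIS paper. The notation ($\theta_s$ for latent draws, $\hat{\phi}$ for the fitted guide parameters, $S$ draws, $M$ tail draws) is introduced here for illustration only:

$$
\log r_s = \log p(y, \theta_s) - \log q(\theta_s \mid \hat{\phi}), \qquad \theta_s \sim q(\theta \mid \hat{\phi}), \quad s = 1, \dots, S
$$

A generalized Pareto distribution is fitted to the $M = \min(S/5,\ 3\sqrt{S})$ largest ratios $r_s$, and $\hat{k}$ is its estimated shape parameter. Yao et al. read it roughly as: $\hat{k} < 0.5$, the approximation is good; $0.5 \le \hat{k} < 0.7$, usable but not great; $\hat{k} \ge 0.7$, the approximation is too far from the posterior to trust.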
I have created a preliminary implementation for my work (seen here) that calculates the log importance ratios and passes them to the arviz.stats.stats._psislw() function to calculate $\hat{k}$. It could easily be generalized to work with any model and guide.
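To make the $\hat{k}$ step concrete, here is a rough, dependency-free NumPy sketch of that part. It is not ArviZ's code and not the implementation linked above. It assumes the usual PSIS recipe: a tail of size $M = \min(S/5, 3\sqrt{S})$, a generalized Pareto fit via the Zhang & Stephens (2009) estimator, and a weak prior shrinking $\hat{k}$ toward 0.5. The constants (grid size `30 + sqrt(n)`, grid scale 3, prior weight 10) are assumptions chosen to mirror what ArviZ/loo appear to use, and the names `pareto_khat` and `_gpd_fit_shape` are hypothetical. It only returns $\hat{k}$ and does no smoothing of the weights.

```python
import numpy as np


def _gpd_fit_shape(exceedances):
    """Estimate the generalized Pareto shape k from positive tail exceedances.

    Sketch of the Zhang & Stephens (2009) empirical-Bayes estimator; the
    constants (grid size, scale 3, prior weight 10 toward k = 0.5) are assumed.
    """
    x = np.sort(np.asarray(exceedances, dtype=float))
    n = x.size
    m = 30 + int(np.sqrt(n))
    # Grid of candidate theta = -k / sigma values; all satisfy theta < 1 / max(x).
    j = np.arange(1, m + 1, dtype=float)
    x_quartile = x[int(n / 4 + 0.5) - 1]
    theta = 1.0 / x[-1] + (1.0 - np.sqrt(m / (j - 0.5))) / (3.0 * x_quartile)
    # Profile MLE of k for each theta and the corresponding profile log-likelihood.
    k = np.log1p(-theta[:, None] * x).mean(axis=1)
    loglik = n * (np.log(-theta / k) - k - 1.0)
    # Normalised weights over the grid, computed stably (softmax of loglik).
    w = np.exp(loglik - loglik.max())
    w /= w.sum()
    theta_hat = np.sum(w * theta)
    k_hat = np.log1p(-theta_hat * x).mean()
    # Assumed weak prior: shrink toward 0.5 with weight of 10 pseudo-observations.
    return (n * k_hat + 10 * 0.5) / (n + 10)


def pareto_khat(log_ratios):
    """Pareto k-hat for a 1-D array of log importance ratios (r_eff assumed 1)."""
    lr = np.asarray(log_ratios, dtype=float)
    lr = lr - lr.max()  # rescale so the largest ratio is 1; k-hat is scale-invariant
    s = lr.size
    # Tail size M = min(S / 5, 3 * sqrt(S)).
    m = int(np.ceil(min(0.2 * s, 3.0 * np.sqrt(s))))
    lr_sorted = np.sort(lr)
    cutoff = lr_sorted[-m - 1]
    tail = lr_sorted[lr_sorted > cutoff]
    if tail.size <= 4:
        return np.inf  # too few tail values to fit a generalized Pareto
    # Fit the GPD to ratio exceedances over the cutoff (in ratio space, not log space).
    return _gpd_fit_shape(np.exp(tail) - np.exp(cutoff))
```

As a sanity check, ratios that are exactly Lomax (Pareto II) distributed with shape parameter 0.8, e.g. `np.log(rng.pareto(1.25, size=20000))`, should give $\hat{k}$ near 0.8, while the bounded ratios of a N(0, 1) target under a wider N(0, 2²) proposal (log ratios `-0.375 * z**2` up to a constant) should give $\hat{k}$ well below 0.5.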
Any thoughts would be much appreciated!
That's very cool! I didn't know there was a nice diagnostic for how "good" variational approximations are. Just wondering if this could be a nice contribution to arviz itself so the method can also be used for other inference frameworks, e.g., Stan?
There is a function in arviz (arviz.stats.stats._psislw()) which computes $\hat{k}$ (along with the smoothed log weights) given the log importance ratios. But computing the log importance ratios requires the NumPyro model, the guide, the fitted guide parameters, and the samples drawn from the guide that you want to evaluate; hence my suggestion of a new method that does all of this in one go.
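The ingredients listed above can be illustrated without NumPyro. The sketch below uses a hypothetical conjugate toy model ($z \sim N(0, 1)$, $y_i \mid z \sim N(z, 1)$) and a Normal guide with fitted parameters `loc` and `scale`. The functions `model_log_joint`, `guide_log_density`, and `log_importance_ratios` are hypothetical stand-ins; in NumPyro one would presumably evaluate the model and guide with something like `numpyro.infer.util.log_density` instead, which is not shown here.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    """Elementwise log density of a Normal(loc, scale) distribution."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def model_log_joint(z, y):
    """log p(y, z) for the toy model z ~ N(0, 1), y_i | z ~ N(z, 1)."""
    z = np.asarray(z, dtype=float)
    log_prior = normal_logpdf(z, 0.0, 1.0)
    # Broadcast draws (S, 1) against data (1, N), then sum over observations.
    log_lik = normal_logpdf(y[None, :], z[:, None], 1.0).sum(axis=1)
    return log_prior + log_lik


def guide_log_density(z, params):
    """log q(z) for a Normal guide with fitted parameters {'loc', 'scale'}."""
    return normal_logpdf(z, params["loc"], params["scale"])


def log_importance_ratios(y, params, num_samples, rng):
    """Draw samples from the fitted guide and return log p(y, z) - log q(z)."""
    z = rng.normal(params["loc"], params["scale"], size=num_samples)
    return model_log_joint(z, y) - guide_log_density(z, params)
```

If the guide equals the exact posterior, $N\big(\sum_i y_i/(n+1),\ 1/(n+1)\big)$, every log ratio equals the log evidence $\log p(y)$, which makes a handy sanity check; any mismatch between guide and posterior makes the ratios vary, and those ratios are what gets passed to the $\hat{k}$ estimate.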
Do you know if it's possible to evaluate those other quantities from the arviz InferenceData object?
I don't think so, unless arviz.InferenceData somehow incorporates a numpyro model and guide?