mcmc-monitor

update ESS calculation code to be consistent with bayes-kit

Open: magland opened this issue 2 years ago • 18 comments

There are some unresolved questions about exactly which formulae should be used (will need to discuss with @bob-carpenter).

Here is the bayes-kit implementation, which may have some issues:

https://github.com/flatironinstitute/bayes-kit/blob/main/bayes_kit/ess.py

Here is the current mcmc-monitor implementation, which may need to be adjusted:

https://github.com/flatironinstitute/mcmc-monitor/blob/459e8b6814c745bbf7681f3035b409177a540057/src/MCMCMonitorDataManager/stats/ess.ts

The critical functions in question are:

first_neg_pair_start https://github.com/flatironinstitute/mcmc-monitor/blob/459e8b6814c745bbf7681f3035b409177a540057/src/MCMCMonitorDataManager/stats/ess.ts#L72-L82

and ess_imse https://github.com/flatironinstitute/bayes-kit/blob/22a3e9ff31f2268f47e13a737dc57c81a26ae917/bayes_kit/ess.py#L99-L135

The way mcmc-monitor does it now, sigma_sq_hat (which I think is the IAT) can never be less than 1, which I believe is a desirable property. But of course we'll want to be consistent with bayes-kit.

tagging: @jsoules @WardBrian

magland, Feb 14 '23 13:02

I believe you will want to be able to report an ESS greater than the number of draws. This assumes the function people are interested in is the mean, but in that case anticorrelated draws (commonly produced by HMC) are actually better than independent samples.
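To make this concrete for an illustrative model (a stationary AR(1) chain with lag-1 autocorrelation rho, not mcmc-monitor's data): the integrated autocorrelation time for the mean is (1 + rho) / (1 - rho), so ESS = N * (1 - rho) / (1 + rho), which exceeds N whenever rho < 0. A minimal sketch; `ar1_iat` and `ar1_ess` are hypothetical helper names:

```python
def ar1_iat(rho: float) -> float:
    """Theoretical integrated autocorrelation time of a stationary AR(1) process.

    IAT = 1 + 2 * sum_{k>=1} rho^k = (1 + rho) / (1 - rho), valid for |rho| < 1.
    """
    if not -1.0 < rho < 1.0:
        raise ValueError("AR(1) is stationary only for |rho| < 1")
    return (1.0 + rho) / (1.0 - rho)


def ar1_ess(n_draws: int, rho: float) -> float:
    """Theoretical ESS for estimating the mean from n_draws AR(1) draws."""
    return n_draws / ar1_iat(rho)


for rho in (0.5, 0.0, -0.5):
    print(f"rho={rho:+.1f}  IAT={ar1_iat(rho):.3f}  ESS={ar1_ess(1000, rho):.1f}")
```

With rho = -0.5 the theoretical ESS is three times the number of draws.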

WardBrian, Feb 14 '23 14:02

> I believe you will want to be able to report an ESS greater than the number of draws. This assumes the function people are interested in is the mean, but in that case anticorrelated draws (commonly produced by HMC) are actually better than independent samples.

I guess that makes sense. But from what I understand, bayes-kit will essentially compute the area under the main (positive) lobe, and the only time you get IAT < 1 is when, for some technical reason, a bit of the negative dip after the main lobe is counted as part of the area. The present mcmc-monitor code never includes that negative piece. I would be surprised if the bayes-kit version reliably picks up anticorrelated behavior, but perhaps I am misunderstanding something.
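The difference can be seen on the exact AR(1) autocorrelations ac[k] = rho^k with rho = -0.5. The two truncation rules below are hypothetical simplifications for illustration, not the actual mcmc-monitor or bayes-kit code: one stops at the first negative single autocorrelation (the "main lobe"), the other stops at the first even-indexed pair (ac[2m], ac[2m+1]) whose sum is negative.

```python
def ar1_acf(rho: float, max_lag: int) -> list[float]:
    """Exact autocorrelations rho^k, k = 0..max_lag, of a stationary AR(1) process."""
    return [rho**k for k in range(max_lag + 1)]


def iat_main_lobe(ac: list[float]) -> float:
    """-1 + 2 * (sum of ac[k] up to, not including, the first negative ac[k]).

    Since ac[0] = 1 and every included term is >= 0, the result is always >= 1.
    """
    total = 0.0
    for a in ac:
        if a < 0:
            break
        total += a
    return -1.0 + 2.0 * total


def iat_even_pairs(ac: list[float]) -> float:
    """-1 + 2 * (sum of pair sums ac[2m] + ac[2m+1] up to the first negative one).

    Negative lags are kept when they sit inside a positive pair, so the
    result can drop below 1.
    """
    total = 0.0
    for i in range(0, len(ac) - 1, 2):
        pair = ac[i] + ac[i + 1]
        if pair < 0:
            break
        total += pair
    return -1.0 + 2.0 * total


ac = ar1_acf(-0.5, 100)
print(iat_main_lobe(ac))   # 1.0: the negative lag-1 term is never counted
print(iat_even_pairs(ac))  # close to 1/3, the exact value (1 + rho) / (1 - rho)
```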

magland, Feb 14 '23 14:02

I can't really speak to the bayes_kit code, just to the behavior I expect from tools like stansummary (which I believe is calculated with this code).

WardBrian, Feb 14 '23 14:02

The bayes-kit code just implements the mathematical definition from the paper as described in the Stan reference manual.

The work in Stan's implementation is delegated to the function defined in stan/analyze/mcmc/compute_effective_sample_size.hpp.

There's a comment starting at line 98 that explains how this is adjusted for antiautocorrelation. The monotonicity condition is implemented differently, though it may work out to the same thing.

P.S. We need to update the description in our reference manual of what's being computed so that it matches the actual computation. What happened, I believe, is that we originally used the standard definition from Geyer's paper, and the computation was later updated to handle the kind of antiautocorrelated chains we see with Hamiltonian Monte Carlo.

bob-carpenter, Feb 14 '23 15:02

I guess we should wait until the bayes_kit implementation is solidified, and then we can just copy it exactly. For now there's a working calculation for ESS that should be pretty close.

magland, Feb 14 '23 15:02

I sat down and did the algebra, and I think the confusion is the offset on the pairwise positivity constraint. I originally had the wrong implementation there, using pairs (1, 2), (3, 4), .... Instead, what you want to do is take pairs (0, 1), (2, 3), ....
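The offset matters a lot for anticorrelated chains. On the exact AR(1) autocorrelations with rho = -0.5 (illustrative values), every pair starting at an odd lag has a negative sum, so an estimator using pairs (1, 2), (3, 4), ... truncates immediately and keeps only ac[0] = 1, while every pair starting at an even lag is non-negative. A sketch with a hypothetical helper `pair_sums`:

```python
def pair_sums(ac: list[float], start: int) -> list[float]:
    """Sums ac[i] + ac[i + 1] for i = start, start + 2, start + 4, ..."""
    return [ac[i] + ac[i + 1] for i in range(start, len(ac) - 1, 2)]


rho = -0.5
ac = [rho**k for k in range(8)]  # 1, -0.5, 0.25, -0.125, ...

print(pair_sums(ac, start=1))  # [-0.25, -0.0625, -0.015625]: negative for -1 < rho < 0
print(pair_sums(ac, start=0))  # [0.5, 0.125, 0.03125, 0.0078125]: >= 0 for |rho| < 1
```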

For example, if we have a simple time-series model like an AR(1) process with autocorrelation rho in (-1, 1), then we know the autocorrelation at lag k is rho^|k|:

ac[-2] = rho^2
ac[-1] = rho
ac[0] = 1
ac[1] = rho
ac[2] = rho^2

Our estimator for the integrated autocorrelation time is

IAT = ... + ac[-2] + ac[-1] + ac[0] + ac[1] + ac[2] + ...

but by symmetry, ac[n] = ac[-n], so we evaluate it using just the non-negative lags

IAT = -1 + 2 * [ ac[0] + ac[1] + ... ]

where the -1 term is to avoid double counting the lag-zero autocorrelation ac[0] = 1.

Now, if we have rho = -0.9, we'll get something like this:

IAT = -1 + 2 * [ (1 + -0.9) + (0.81 + -0.729) + ... ]
    < 1
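Summing the geometric series shows that this is not an artifact of truncation. Assuming the exact AR(1) autocorrelations ac[k] = rho^k and keeping all terms, grouped into the (0, 1), (2, 3), ... pairs:

```latex
\mathrm{IAT}
  = -1 + 2\sum_{m=0}^{\infty}\left(\rho^{2m} + \rho^{2m+1}\right)
  = -1 + \frac{2(1+\rho)}{1-\rho^{2}}
  = -1 + \frac{2}{1-\rho}
  = \frac{1+\rho}{1-\rho},
\qquad
\rho = -0.9 \;\Rightarrow\; \mathrm{IAT} = \frac{0.1}{1.9} \approx 0.053 < 1 .
```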

Here's the behavior on the current branch, where sample_ar1 uses (lag-1) autocorrelation rho and generates a sample of size N.

>>> y = sample_ar1(rho = 0, N = 1000)
>>> len(y)
1000
>>> ess(y)
971.3854635513478
>>> 
>>> y = sample_ar1(rho = 0.5, N = 1000)
>>> ess(y)
376.41893467549767
>>> 
>>> y = sample_ar1(rho = -0.5, N = 1000)
>>> ess(y)
2999.0927300430913

So you can see in the last case that estimated ESS > N.
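The session above relies on `sample_ar1` and `ess` from the bayes-kit branch, which are not shown here. The stand-ins below are hypothetical pure-Python sketches whose names mirror the session; they are not the bayes-kit source. `sample_ar1` starts from the stationary distribution, and `ess` divides N by the even-pair IAT estimate, assuming the usual biased (denominator N) autocorrelation estimator.

```python
import math
import random


def sample_ar1(rho: float, N: int, seed: int = 0) -> list[float]:
    """Draw N values of y[t] = rho * y[t-1] + e[t] with e[t] ~ Normal(0, 1)."""
    rng = random.Random(seed)
    # Start in the stationary distribution, whose variance is 1 / (1 - rho^2).
    y = [rng.gauss(0.0, 1.0 / math.sqrt(1.0 - rho**2))]
    for _ in range(N - 1):
        y.append(rho * y[-1] + rng.gauss(0.0, 1.0))
    return y


def ess(y: list[float]) -> float:
    """ESS = N / IAT, with IAT = -1 + 2 * (sum of even-indexed pair sums
    ac[2m] + ac[2m+1], truncated at the first negative pair)."""
    n = len(y)
    mean = sum(y) / n
    dev = [v - mean for v in y]
    c0 = sum(d * d for d in dev) / n

    def ac(lag: int) -> float:
        # Biased sample autocorrelation: denominator n at every lag.
        return sum(dev[t] * dev[t + lag] for t in range(n - lag)) / (n * c0)

    total = 0.0
    for i in range(0, n - 1, 2):
        pair = ac(i) + ac(i + 1)
        if pair < 0:
            break
        total += pair
    iat = -1.0 + 2.0 * total
    if iat <= 0:
        raise ValueError("non-positive IAT estimate; chain too short or too anticorrelated")
    return n / iat


for rho in (0.0, 0.5, -0.5):
    print(rho, ess(sample_ar1(rho=rho, N=1000)))
```

The exact numbers depend on the seed, but the pattern matches the session: ESS well below N for rho = 0.5 and above N for rho = -0.5.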

bob-carpenter, Feb 14 '23 22:02

Thanks @bob-carpenter. Okay, I will adjust our ess.ts to exactly match what's in your ess.py.

One minor note. Shouldn't the comment on first_neg_pair_start be adjusted?

https://github.com/flatironinstitute/bayes-kit/blob/40a34129a6313056bc4fedc42606578cee6dff77/bayes_kit/ess.py#L65-L67

Right now it reads

> Return: index of first element whose sum with following element is negative, or the number of elements if there is no such element

But really it should be "Index of the first even-indexed element..."
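A sketch of a function matching the corrected description. Only the name first_neg_pair_start comes from the discussion; the body is an illustration, not the bayes-kit source.

```python
from typing import Sequence


def first_neg_pair_start(chain: Sequence[float]) -> int:
    """Return the index of the first even-indexed element whose sum with the
    following element is negative, or the number of elements if there is no
    such element.
    """
    n = len(chain)
    for i in range(0, n - 1, 2):  # only pairs (0, 1), (2, 3), ...
        if chain[i] + chain[i + 1] < 0:
            return i
    return n


acor = [1.0, 0.3, 0.2, -0.1, -0.15, 0.05]
print(first_neg_pair_start(acor))  # 4: the odd-offset pair (3, 4) is never checked
```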

magland, Feb 15 '23 14:02

@bob-carpenter There is still a typo in ess_imse where prevmin should be min_prev, which I guess will be resolved with PR https://github.com/flatironinstitute/bayes-kit/pull/16

But also, could you double-check the indexing there? In this function (ess_imse) you are taking pairs (1, 2), (3, 4), etc., which seems inconsistent with first_neg_pair_start (but maybe that's okay?).

magland, Feb 15 '23 15:02

That's a much more accurate way to doc, so I'll update. I'll also check code again. It should be checking pairs (0, 1), (1, 2), ....

The estimator without the monotonic-downward constraint on pair sums will have a bit more variance, but should still be OK.

bob-carpenter, Feb 15 '23 15:02

Sorry, Bob, I'm still confused. You've said:

> That's a much more accurate way to doc, so I'll update.

presumably referring to Jeremy's note that

> really it should be "Index of the first even-indexed element...

But then you go on to say:

> It should be checking pairs (0, 1), (1, 2),

which looks like it would be every pair, not just the even-indexed ones?

jsoules, Feb 15 '23 15:02

@jsoules My understanding is that "That's a much more accurate way to doc, so I'll update." refers to the first_neg_pair_start function, whereas the rest of Bob's paragraph refers to ess_imse.

magland, Feb 15 '23 15:02

Ack, I mean pairs (0, 1), (2, 3), .... Doc is hard, especially before coffee.

bob-carpenter, Feb 15 '23 15:02

I noticed that the current ESS computation seems to ignore the Rhat part. For example, https://flatironinstitute.github.io/mcmc-monitor/?s=https://mcmc-monitor-proxy.herokuapp.com/s/c76d31a29d3d08e9a536&webrtc=0#/run/wu7p5aeu shows a total ESS of 1140, but as Rhat is 6.95, the total ESS should be less than 4. It seems the current total ESS is just the sum of the individual chains' ESSs, but it should take into account whether the chains are mixing (as detected by Rhat).
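To see why a summed ESS can be misleading here, it helps to look at what R-hat measures. A minimal sketch of the original (non-split, non-rank) R-hat, using the textbook formula with mean within-chain variance W and between-chain variance B over M chains of length N. The function rhat_classic and the synthetic chains are hypothetical; they are not the data from the linked run.

```python
import math
from statistics import mean, variance


def rhat_classic(chains: list[list[float]]) -> float:
    """Original (non-split, non-rank) R-hat for M >= 2 chains of equal length N."""
    n = len(chains[0])
    chain_means = [mean(c) for c in chains]
    w = mean(variance(c) for c in chains)   # W: mean within-chain variance
    b = n * variance(chain_means)           # B: between-chain variance
    var_plus = (n - 1) / n * w + b / n      # pooled estimate of the target variance
    return math.sqrt(var_plus / w)


# Two chains that move quickly locally but sit in different places.
chain_a = [0.0 + 0.1 * math.sin(t) for t in range(500)]
chain_b = [5.0 + 0.1 * math.sin(t) for t in range(500)]
print(rhat_classic([chain_a, chain_b]))  # far above 1
```

Each chain on its own can report a healthy ESS, so summing them hides the fact that the chains disagree.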

avehtari, May 09 '23 08:05

That's exactly what it's doing now. I need to implement the more sophisticated R-hat estimator. That will just be a plug-in change to all the visualizations and plots.

bob-carpenter, May 09 '23 11:05

Once R-hat is implemented correctly, what should be the formula for the total ESS? Is it the sum of the individual ess's divided by R-hat, or something like that?

magland, May 09 '23 14:05

See Section 3.2 in https://doi.org/10.1214/20-BA1221. For combining the information from the individual chains and Rhat, the key equation is (3.10). There is a correct Python implementation in the ArviZ package.
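A sketch of one reading of Section 3.2 (illustrative only, not ArviZ's or bayes-kit's code; check it against the paper before relying on it). Instead of summing per-chain ESS, it computes one combined autocorrelation per lag, rho_t = 1 - (W - mean_m acov_{t,m}) / var_plus. Here W is the mean within-chain variance, acov_{t,m} is chain m's lag-t autocovariance, and var_plus = (N - 1)/N * W + B/N is the pooled variance R-hat uses. Disagreeing chains inflate var_plus, which keeps rho_t near 1 and drives the ESS down. The sketch assumes the biased (denominator N) autocovariance, which may differ from the reference implementation by a factor N/(N - 1), and it skips splitting and rank normalization. The name ess_multichain is hypothetical.

```python
import math
import random


def ess_multichain(chains: list[list[float]]) -> float:
    """Combined ESS for M >= 2 chains of equal length N (simplified sketch)."""
    m_chains, n = len(chains), len(chains[0])
    means = [sum(c) / n for c in chains]
    devs = [[x - mu for x in c] for c, mu in zip(chains, means)]
    w = sum(sum(d * d for d in dv) / (n - 1) for dv in devs) / m_chains  # W
    grand = sum(means) / m_chains
    b_over_n = sum((mu - grand) ** 2 for mu in means) / (m_chains - 1)  # B / N
    var_plus = (n - 1) / n * w + b_over_n

    def rho(t: int) -> float:
        # Combined autocorrelation at lag t: 1 - (W - mean_m acov_{t,m}) / var_plus.
        if t == 0:
            return 1.0
        acov = sum(
            sum(dv[i] * dv[i + t] for i in range(n - t)) / n for dv in devs
        ) / m_chains
        return 1.0 - (w - acov) / var_plus

    total, min_prev = 0.0, math.inf
    for t in range(0, n - 1, 2):
        pair = rho(t) + rho(t + 1)
        if pair < 0:
            break
        min_prev = min(min_prev, pair)  # keep pair sums non-increasing
        total += min_prev
    return m_chains * n / (-1.0 + 2.0 * total)


rng = random.Random(1)
mixing = [[rng.gauss(0.0, 1.0) for _ in range(200)] for _ in range(2)]
stuck = [[0.1 * math.sin(t) for t in range(200)],
         [5.0 + 0.1 * math.sin(t) for t in range(200)]]
print(ess_multichain(mixing))  # on the order of the 400 total draws
print(ess_multichain(stuck))   # close to 1: the chains disagree
```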

avehtari, May 09 '23 14:05

Thanks @avehtari. I'll read through this carefully (quite technical); I think it's important to get it right. I'm on the fence about whether or not to wait until this is implemented in bayes-kit. I'll chat with @jsoules and @bob-carpenter.

magland, May 10 '23 12:05

The current BayesKit only has the original R-hat algorithm and the ESS estimator that just sums across chains. In this situation the ESS estimator will be biased to the high side when R-hat >> 1.

There are three refined estimators of R-hat and different estimators of ESS:

  • Splitting chains. This is useful for one chain and it's cheap. I'd like to see a real model where it helps with multiple chains. The papers only have made-up examples.

  • Ranks. Reducing values to ranks helps with tails.

  • Using R-hat variance estimator for ESS. This penalizes ESS if R-hat is high. This requires plugging multiple chains into ESS.
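The first two refinements are data transformations applied before computing R-hat or ESS. In the hypothetical sketch below, split_chains halves every chain, and rank_normalize replaces each draw by the normal quantile of its pooled rank. It assumes Blom's offsets, (r - 3/8) / (S + 1/4); Vehtari et al. use a normal-score transform of this kind, but check the exact constants against the paper or ArviZ. Ties are broken by position here, whereas a careful implementation would average tied ranks.

```python
from statistics import NormalDist


def split_chains(chains: list[list[float]]) -> list[list[float]]:
    """Split every chain into first and second halves (an odd-length chain
    loses its middle draw), doubling the number of chains."""
    out = []
    for c in chains:
        half = len(c) // 2
        out.append(c[:half])
        out.append(c[len(c) - half:])
    return out


def rank_normalize(chains: list[list[float]]) -> list[list[float]]:
    """Map each draw to Phi^{-1}((r - 3/8) / (S + 1/4)), where r is its 1-based
    rank among all S pooled draws. Ties are broken by chain index and position."""
    pooled = sorted((x, m, i) for m, c in enumerate(chains) for i, x in enumerate(c))
    s = len(pooled)
    std_normal = NormalDist()  # mu = 0, sigma = 1
    z = [[0.0] * len(c) for c in chains]
    for r, (_, m, i) in enumerate(pooled, start=1):
        z[m][i] = std_normal.inv_cdf((r - 0.375) / (s + 0.25))
    return z


chains = [[3.0, 100.0, 1.0, 2.0], [5.0, 4.0, 1000.0, 0.5]]
print(split_chains(chains))    # [[3.0, 100.0], [1.0, 2.0], [5.0, 4.0], [1000.0, 0.5]]
print(rank_normalize(chains))  # 100 and 1000 become ordinary normal scores
```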

I won't be able to get to this for a couple weeks. In terms of functionality, a new R-hat/ESS estimator should be pluggable at any point.

If you think it's critical to get a better R-hat/ESS estimator, you could try ArviZ. They have implemented a CmdStanPy interface to read in data, which is the trickiest part of ArviZ. I didn't want to include a dependency from BayesKit because ArviZ is very heavy in terms of both data structures and dependencies.

bob-carpenter, May 10 '23 13:05