
Leave Future Out Cross Validation

Open · jessegrabowski opened this issue 2 years ago · 13 comments

Hello,

I am interested in adapting the refitting wrappers to implement LFO-CV, as described by Bürkner, Gabry, and Vehtari (2020), with the goal of cleaning everything up, writing unit tests, and submitting a PR. I am following the R implementation the authors provide in the paper's GitHub repo.

I have a somewhat working prototype, which I put into a notebook here. The "exact" method, where the model is completely re-fit at every time step, matches the R output very closely, so I am feeling pretty good that I am roughly on the right track. The much more important approximate method, however, seems to have a problem: the Pareto k-hat estimates come out much too high, and as a result the model re-fits far too often. The R implementation requires only a single re-fit for the data I present in the notebook, whereas my implementation requires more than a dozen.
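For orientation, here is a rough sketch of the approximate 1-step-ahead PSIS-LFO-CV loop as I understand it from the paper. fit_model and pointwise_loglik are hypothetical placeholders for the model-specific refitting and out-of-sample log-likelihood code, and 0.7 is the paper's default k-hat threshold:

import numpy as np
import arviz as az
from scipy.special import logsumexp


def approx_lfo_1sap(y, L, tau=0.7):
    # Hypothetical sketch, not a drop-in implementation. fit_model(data) and
    # pointwise_loglik(fit, y, idx) stand in for model-specific refitting and
    # out-of-sample pointwise log-likelihood evaluation; both are assumed to
    # return xarray objects with chain/draw (and time) dimensions.
    N = len(y)
    fit = fit_model(y[:L])   # initial fit conditions on the first L observations
    i_refit = L              # number of observations the current fit conditions on
    elpds, khats = [], []

    for i in range(L, N):    # predict y[i] one step ahead
        ll_next = pointwise_loglik(fit, y, [i]).sum("time").stack(__sample__=("chain", "draw"))

        if i == i_refit:
            # the fit conditions on exactly y[:i], so no importance weighting is needed
            elpds.append(logsumexp(ll_next.values) - np.log(ll_next.size))
            continue

        # Eq. 10: the log importance ratios are the summed log-likelihoods of the
        # observations seen since the last refit, i.e. y[i_refit:i]
        log_ratio = pointwise_loglik(fit, y, np.arange(i_refit, i)).sum("time")
        lr_stacked = log_ratio.stack(__sample__=("chain", "draw"))
        log_weights, k_hat = az.psislw(lr_stacked)  # reff omitted here for brevity
        khats.append(float(k_hat))

        if k_hat > tau:
            # Pareto k too large: refit conditioning on y[:i] and use the exact ELPD term
            fit = fit_model(y[:i])
            i_refit = i
            ll_next = pointwise_loglik(fit, y, [i]).sum("time").stack(__sample__=("chain", "draw"))
            elpds.append(logsumexp(ll_next.values) - np.log(ll_next.size))
        else:
            # combine the PSIS log weights with the predictive log-likelihood of y[i]
            elpds.append(logsumexp((log_weights + ll_next).values))

    return np.sum(elpds), np.array(khats)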

I am positive I have made a mistake in my computations, and I suspect it is in the psis_log_weights function:

import arviz as az
import numpy as np


def psis_log_weights(ll_predict):
    """Compute PSIS-smoothed log weights from the out-of-sample pointwise log-likelihoods."""
    # TODO: Support user-supplied name on time dim

    # Equation 10: sum the pointwise log-likelihoods over the time dimension
    # to get one log importance ratio per posterior draw
    log_ratio = ll_predict.sum("time")
    lr_stacked = log_ratio.stack(__sample__=["chain", "draw"])

    # Relative efficiency of the importance ratios.
    # Method "identity" gives the closest match to loo::psis output in R -- is it right?
    reff = az.ess(1 / np.exp(log_ratio), method="identity", relative=True).x.values

    # TODO: loo::psis doesn't need lr_stacked to be negative? See:
    #  https://github.com/paul-buerkner/LFO-CV-paper/blob/master/sim_functions.R#L231
    log_weights, k_hat = az.psislw(-lr_stacked, reff)

    return log_weights, k_hat
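For context, a hypothetical call looks like this, assuming ll_predict is an xarray DataArray of out-of-sample pointwise log-likelihoods with chain, draw, and time dims (produced elsewhere in the notebook):

log_weights, k_hat = psis_log_weights(ll_predict)
refit_needed = k_hat > 0.7  # the usual Pareto k-hat threshold for triggering a refit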

I am worried I'm not familiar enough with the required inputs to az.ess and az.psislw, or with how these differ from the corresponding functions in the R loo package, to see where I am going wrong. I am hoping the community here will spot something dumb I did right away.

Thanks for any help you can give!

jessegrabowski avatar Jul 11 '22 14:07 jessegrabowski

Hi, have you tried the good old 'print all steps' in R and Python?

ahartikainen avatar Jul 11 '22 14:07 ahartikainen

Haha yeah, I stepped through them both in debug mode side-by-side and took notes. That's the only way I got this far.

I have a suspicion I'm being a knucklehead by trying to follow the R code too closely. The places I'm confused are:

  1. When I compute the reff, I use the un-stacked log ratio (because az.ess expects a chain dimension), while az.psislw expects chain and draw to be stacked into __sample__. Does this create inconsistencies?
  2. It's not super clear to me what the "method" argument in az.ess is doing, and there doesn't seem to be an equivalent argument in e.g. loo::relative_eff. I went with "identity" just because it made the outputs match, but not for any principled reason.
  3. Where is this 1 / np.exp(log_ratio) coming from? I just blindly copied it from R because it makes the log_weights match more closely, but in az.loo there's nothing of the sort for computing the relative ess. Here's the original R code; in both cases logratio is a (chains * samples, 1) vector made by summing the out-of-sample log-likelihoods across the time dimension (a rough ArviZ rendering of these two calls is sketched after this list):
  r_eff_lr <- loo::relative_eff(logratio, chain_id = chain_id)
  r_eff <- loo::relative_eff(1 / exp(logratio), chain_id = chain_id)
  4. For that matter, az.loo doesn't use the relative=True parameter -- is there any reason for that?
  5. az.psislw requires a negative log-likelihood, but loo::psis does not?
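For concreteness, this is roughly how I would render those two R calls in ArviZ terms (a sketch only, assuming log_ratio is the un-stacked (chain, draw) DataArray from psis_log_weights above):

import numpy as np
import arviz as az

# r_eff_lr <- loo::relative_eff(logratio, chain_id = chain_id)
r_eff_lr = az.ess(log_ratio, method="identity", relative=True).x.values

# r_eff <- loo::relative_eff(1 / exp(logratio), chain_id = chain_id)
# the extra 1 / exp(.) transform is exactly what question 3 asks about
r_eff = az.ess(1 / np.exp(log_ratio), method="identity", relative=True).x.values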

jessegrabowski avatar Jul 11 '22 15:07 jessegrabowski

Hi, thanks for getting this rolling @jessegrabowski! Trying to catch up.

  1. There should not be any inconsistencies due to that. az.ess and its different methods are tested quite thoroughly, I believe including comparisons to R.

  2. The method argument defines how the effective sample size is computed. There are different ways to go about this, such as "splitting" the 4 chains/1000 draws into 8 chains/500 draws, or using ranks to compute the ess instead of raw values (see https://arxiv.org/abs/1903.08008 for details; the different methods also correspond to different functions in the posterior package).

    After a 1 min skim of the loo::relative_eff function it looks like the computations there match the identity method (aka no split, no ranks...). This also makes sense conceptually: I assume here we don't really want to use ess as a diagnostic, but instead assume we have samples from a converged fit and want an estimate of the relative ess that is as precise as possible.

  3. IIRC, the reff parameter in psislw is there to precalculate some constants and make PSIS more efficient. I think ArviZ currently takes the mean over the ess of all posterior variables. It is probably a better idea to use the ess of the quantity on which we will actually apply PSIS, which seems to be what is happening here.

  4. I don't think there is any reason for that. It might even be because loo was written before ess had a relative argument and it wasn't updated after that.

  5. I haven't yet gone through all the LFO code in detail, but from the docs both az.psislw and loo::psis behave similarly. Their expected input is the log_ratios, which for "standard" loo-psis are the negative pointwise log-likelihood values. If the log ratios are defined differently here then the negation might not be necessary anymore. Ref: http://mc-stan.org/loo/reference/psis.html#arguments, quote:

    An array, matrix, or vector of importance ratios on the log scale (for PSIS-LOO these are negative log-likelihood values).

    So az.loo passes -log_lik as log_ratios to psislw, but psislw can be used to perform Pareto smoothed importance sampling on any array of log ratios; that "constraint" is part of loo, not of psis, which is why the negation is done in az.loo and not in az.psislw. There is a small sketch of this convention below.
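To make that convention concrete, a small sketch (variable names are illustrative only, assuming an InferenceData with a log_likelihood group containing a variable y and a precomputed reff):

import arviz as az

# standard PSIS-LOO convention: the log ratios handed to psislw are the negative
# pointwise log-likelihoods, which is what az.loo does internally
log_lik = idata.log_likelihood["y"].stack(__sample__=("chain", "draw"))
log_weights, khat = az.psislw(-log_lik, reff)

# psislw itself is agnostic about the sign: it smooths whatever log ratios it is
# given, so for LFO the sign should follow however the Equation 10 ratio is
# defined, not the LOO convention.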

Will try to go over the notebook at some point this week. Please keep asking questions here if you want me to focus on something in particular.

OriolAbril avatar Aug 02 '22 22:08 OriolAbril

@jessegrabowski I wonder if this is an issue related to #2148 I posted a while ago.

I implemented PSIS-LFO-CV a while ago at work, and eventually rolled my own psislw function that matched R's method exactly.

Happy to dig into this in more detail too, as it would be useful functionality.

cmgoold avatar Jul 21 '23 05:07 cmgoold

Could be! I've let this project fall by the wayside, but I need to come back to it in the coming weeks/months, so I'm keen to collaborate on it. Could you have a look at that notebook I posted and see if anything strikes you as obviously wrong? Maybe try your modified psislw function and see if that reduces the number of refits to match R? I think it was only 2-3, but it's been a while.

jessegrabowski avatar Jul 21 '23 06:07 jessegrabowski

@jessegrabowski Yes, I can take a look!

cmgoold avatar Jul 21 '23 17:07 cmgoold

> @jessegrabowski I wonder if this is an issue related to #2148 I posted a while ago.
>
> I implemented PSIS-LFO-CV a while ago at work, and eventually rolled my own psislw function that matched R's method exactly.
>
> Happy to dig into this in more detail too, as it would be useful functionality.

@cmgoold Could you share your psislw-function?

HJA24 avatar Nov 13 '23 20:11 HJA24

@cmgoold I was just bumping this to see if you ever dropped your psislw function into the notebook, and whether the number of refits dropped to match R.

If not, would you be able to post your psislw function here so a member of the community could do it?

Thanks!

mathDR avatar Nov 16 '23 14:11 mathDR

@mathDR @HJA24 Hi both. I did not get that far in the end but I can do it. Note, however, that the issue I posted about the discrepancy seems to be due to a difference in how we should be accessing the weights in R. I'll take a look today! Just don't want to lead anyone astray.

cmgoold avatar Nov 17 '23 06:11 cmgoold

@cmgoold that would be much appreciated! I also don't want to lead anyone astray, but I believe the culprit is in the weights. If we look at the leave_future_out_cv function in Jesse's notebook:

elif k_hat > tau or method == 'exact':
    # Exact ELPD
    ...
    elpd = compute_elpd(ll_predict)
else:
    # Approx ELPD
    ...
    elpd = compute_elpd(ll_predict, log_weights)

This would align with Jesse's findings: the result of the approximate method is off, while the exact method matches closely (the log_weights are only used in the approximate branch). A sketch of what I understand each branch to compute is below.
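For reference, here is a hedged sketch of what I understand each branch to compute; this is my reading of the standard formulas, not the notebook's actual compute_elpd:

import numpy as np
from scipy.special import logsumexp


def compute_elpd_sketch(ll_predict, log_weights=None):
    # illustrative only; ll_predict is assumed to be an xarray DataArray of
    # pointwise log-likelihoods with (chain, draw, time) dims
    ll = ll_predict.sum("time").stack(__sample__=("chain", "draw")).values

    if log_weights is None:
        # exact branch: log of the plain average predictive density over draws
        return logsumexp(ll) - np.log(ll.size)

    # approximate branch: weight each draw by its PSIS weight on the log scale
    lw = np.asarray(log_weights)
    lw = lw - logsumexp(lw)  # normalise; a no-op if the weights are already normalised
    return logsumexp(ll + lw)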

HJA24 avatar Nov 17 '23 07:11 HJA24

@jessegrabowski @mathDR @HJA24

I embedded my functions into the notebook. Note: I ran into NaN issues for the r_eff parameter, but don't have the time to go into detail about why that is, so for now I set that variable to 1.0.

This is the final plot. The two methods are pretty close in terms of ELPD, but the approximate k-hats are still quite far off the exact ones. I haven't checked this against the R implementation. Nonetheless, this model only refits once, which seems to match the R implementation?

[Screenshot: plot comparing approximate and exact ELPD and Pareto k-hat values]

Here's my gist: https://gist.github.com/cmgoold/125eee0952c4905f3318de7ab2a11826

cmgoold avatar Nov 17 '23 12:11 cmgoold

I should also point out that Bürkner et al.'s implementation normalises the exponentiated weights, as I do too. This doesn't look like it's accounted for in your original script.
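A minimal sketch of that normalisation step, assuming log_weights is the stacked output of az.psislw (or the raw summed log ratios):

import numpy as np
from scipy.special import logsumexp

lw = np.asarray(log_weights)
weights = np.exp(lw - logsumexp(lw))  # normalised so that weights.sum() == 1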

I'm happy to keep digging away at this and would be interested in helping out putting a PR together too, if we go down that road.

cmgoold avatar Nov 17 '23 13:11 cmgoold

So I would guess that when this work is completed, methods for the various wrappers would be implemented?

Hey @cmgoold any PR I could review?

mathDR avatar Jan 30 '24 20:01 mathDR