Leave Future Out Cross Validation
Hello,
I am interested in adapting the refitting wrappers to implement LFO-CV, as described by Bürkner, Gabry, and Vehtari (2020), with the goal of cleaning everything up, writing unit tests, and submitting a PR. I am following the R implementation the authors provide in the paper's Github repo.
I have a somewhat working prototype, which I put into a notebook here. The "exact" method, where the model is completely re-fit every time step, matches the R output very closely, so I am feeling pretty good that I am roughly on the right track. The much more important approximate method, however, seems to have a problem. The k-hat parameters seem to be much too high, and as a result the model re-fits too many times. The R implementation requires only a single re-fit for the data I present in the notebook, whereas my implementation requires more than a dozen.
I am positive I have made a mistake in my computations, and I suspect it is in the psis_log_weights function:
import arviz as az
import numpy as np


def psis_log_weights(ll_predict):
    # TODO: Support user-supplied name on time dim
    # Equation 10: the log importance ratio is the pointwise log likelihood
    # summed over the time dimension
    log_ratio = ll_predict.sum('time')
    lr_stacked = log_ratio.stack(__sample__=['chain', 'draw'])
    # Method "identity" gives closest match to loo::psis output in R, is it right?
    reff = az.ess(1 / np.exp(log_ratio), method="identity", relative=True).x.values
    # TODO: loo::psis doesn't need lr_stacked to be negative? See:
    # https://github.com/paul-buerkner/LFO-CV-paper/blob/master/sim_functions.R#L231
    log_weights, k_hat = az.psislw(-lr_stacked, reff)
    return log_weights, k_hat
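For context, this is roughly how the function is meant to be used inside the LFO loop (a sketch with hypothetical names: ll_new stands for the pointwise log likelihood of the observations added since the last refit, evaluated under the current fit, with dims (chain, draw, time)):

log_weights, k_hat = psis_log_weights(ll_new)
if k_hat > 0.7:
    # Pareto k-hat above the threshold: the PSIS approximation is considered
    # unreliable, so the model gets refit on all data observed so far
    ...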
I am worried I'm not familiar enough with the required inputs to az.ess and az.psislw, as well as how these differ from the corresponding functions in the R loo package, to see where I am going wrong. I am hoping the community here will spot something dumb I did right away.
Thanks for any help you can give!
Hi, have you tried the good old 'print all steps' in R and Python?
Haha yeah, I stepped through them both in debug mode side-by-side and took notes. That's the only way I got this far.
I have a suspicion I'm being a knucklehead by trying to follow the R code too closely. The places I'm confused are:
- When I compute the reff, I use the un-stacked log ratio (because az.ess expects a chain dimension), while az.psislw expects chain and draw to be stacked into sample. Does this create inconsistencies?
- It's not super clear to me what the "method" argument in az.ess is doing, and there doesn't seem to be an equivalent argument in e.g. loo::relative_eff. I went with "identity" just because it made the outputs match, but not for any principled reason.
- Where is this 1 / np.exp(log_ratio) coming from? I just blindly copied it from R because it makes the log_weights match more closely. But in az.loo there's nothing of the sort for computing the relative ess. Here's the original R code; in both cases logratio is a (chains * samples, 1) vector made by summing the out-of-sample log-likelihoods across the time dimension (see the sketch right after this list for how I read these calls in ArviZ terms):

  r_eff_lr <- loo::relative_eff(logratio, chain_id = chain_id)
  r_eff <- loo::relative_eff(1 / exp(logratio), chain_id = chain_id)

- For that matter, az.loo doesn't use the relative=True parameter -- is there any reason for that?
- az.psislw requires a negative log-likelihood, but loo::psis does not?
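For reference, here is how I read those two R lines in ArviZ terms (just a sketch with a dummy log_ratio; whether method="identity" really is the equivalent of loo::relative_eff is exactly what I'm unsure about):

import numpy as np
import xarray as xr
import arviz as az

# dummy stand-in for the summed out-of-sample log likelihoods, dims (chain, draw)
rng = np.random.default_rng(0)
log_ratio = xr.DataArray(rng.normal(size=(4, 1000)), dims=("chain", "draw"))

# r_eff_lr <- loo::relative_eff(logratio, chain_id = chain_id)
# the unnamed DataArray shows up as a variable called "x" after conversion
r_eff_lr = az.ess(log_ratio, method="identity", relative=True).x.values

# r_eff <- loo::relative_eff(1 / exp(logratio), chain_id = chain_id)
# note that 1 / np.exp(log_ratio) is the same as np.exp(-log_ratio)
r_eff = az.ess(1 / np.exp(log_ratio), method="identity", relative=True).x.values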
Hi, thanks for getting this rolling @jessegrabowski! Trying to catch up.
- There should not be any inconsistencies due to that. ess and the different methods are tested quite thoroughly, I believe also with comparisons to R.
- The method argument defines how the effective sample size is computed. There are different ways to go about this, such as "splitting" the 4 chains/1000 draws into 8 chains/500 draws, or using ranks to compute the ess instead of raw values (see https://arxiv.org/abs/1903.08008 for more details; the different methods also match different functions in the posterior package). After a 1 min skim of the loo::relative_eff function, it looks like the computations there match the identity method (aka no split, no ranks...). This also makes sense conceptually, as I assume here we don't really want to use ess as a diagnostic but instead assume we have samples from a converged fit and want an estimate of the relative ess that is as precise as possible.
- IIRC, the reff parameter in psis is there to precalculate some constants and make psis more efficient. I think ArviZ currently takes the mean over all ess values of the posterior variables. It is probably a better idea to use the ess computed directly for the quantity on which we will run psis, which seems to be what is happening here.
- (6 -- can't get the number in the preview to be a 6) I don't think there is any reason for that. It might even be because loo was written before ess had a relative argument and it wasn't updated after that.
- (7) I haven't yet gone through all the code in LFO in detail, but from the docs both az.psislw and loo::psis behave similarly. Their expected input is the log_ratios, which for "standard" loo-psis are the negative pointwise log likelihood values. If the log ratios are defined differently here then the negative might not be necessary anymore. Ref: http://mc-stan.org/loo/reference/psis.html#arguments, quote: "An array, matrix, or vector of importance ratios on the log scale (for PSIS-LOO these are negative log-likelihood values)." So az.loo passes -log_lik as log_ratios to psislw, but psislw can be used to perform pareto smoothed importance sampling on any array of log ratios; that "constraint" is part of loo, not of psis, which is why it is done in az.loo and not in az.psislw. (See the sketch right after this list.)
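To illustrate that last point, here is a rough sketch of the two sign conventions (using the non_centered_eight example data just to have draws to play with; the LFO line reflects how I read the paper's Eq. 10, not a statement about the final implementation):

import arviz as az

# toy draws just so the snippet runs
idata = az.load_arviz_data("non_centered_eight")
log_lik = idata.log_likelihood["obs"].stack(__sample__=["chain", "draw"])

# Standard PSIS-LOO: the importance ratio for observation i is 1 / p(y_i | theta),
# so az.loo passes the *negative* pointwise log likelihood as the log ratios
loo_log_weights, loo_khat = az.psislw(-log_lik)

# PSIS-LFO (Eq. 10): the importance ratio is the product of the likelihoods of the
# observations added since the last refit, so their summed log likelihood is already
# the log ratio and, it seems, should be passed without negation
obs_dims = [dim for dim in log_lik.dims if dim != "__sample__"]
log_ratio = log_lik.sum(obs_dims)  # pretend these are the newly observed points
lfo_log_weights, lfo_khat = az.psislw(log_ratio)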
Will try to go over the notebook at some point this week. Please continue asking more specific questions here if you want me to focus on something specific.
@jessegrabowski I wonder if this is an issue related to #2148 I posted a while ago.
I implemented PSIS-LFO-CV a while ago at work, and eventually rolled my own psislw function that matched R's method exactly.
Happy to dig into this in more detail too, as it would be useful functionality.
Could be! I've let this project fall by the wayside, but I need to come back to it in the coming weeks/months, so I'm keen to collaborate on it. Could you have a look at that notebook I posted and see if anything strikes you as obviously wrong? Maybe try your modified psislw function and see if that reduces the number of refits to match R? I think it was only 2-3, but it's been a while.
@jessegrabowski Yes, I can take a look!
@cmgoold Could you share your psislw function?
@cmgoold I was just bumping this to see if you ever dropped your psislw function into the notebook, and whether the number of refits came down to match R.
If not, would you be able to post your psislw function here so a member of the community could do it?
Thanks!
@mathDR @HJA24 Hi both. I did not get that far in the end but I can do it. Note, however, that the issue I posted about the discrepancy seems to be due to a difference in how we should be accessing the weights in R. I'll take a look today! Just don't want to lead anyone astray.
@cmgoold that would be much appreciated! I also don't want to lead anyone astray, but I believe the culprit is in the weights.
If we look at the leave_future_out_cv function in Jesse's notebook:
elif k_hat > tau or method == 'exact':
    # Exact ELPD
    ...
    elpd = compute_elpd(ll_predict)
else:
    # Approximate ELPD
    ...
    elpd = compute_elpd(ll_predict, log_weights)
This would align with Jesse's findings: the result of the approximate method is off while the exact method matches closely, and the log weights only enter in the approximate branch.
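For concreteness, here is a rough sketch of what I understand those two branches to compute (this is not Jesse's actual compute_elpd; I'm assuming ll_predict has been reduced to one summed log likelihood per posterior draw):

import numpy as np
from scipy.special import logsumexp

def compute_elpd_sketch(ll_predict, log_weights=None):
    ll_predict = np.asarray(ll_predict)
    if log_weights is None:
        # Exact: log of the plain average of exp(log likelihood) over the draws
        return logsumexp(ll_predict) - np.log(ll_predict.size)
    # Approximate: importance-weighted average using the PSIS log weights,
    # normalised so the weights sum to one -- skipping this normalisation
    # changes the result
    log_weights = np.asarray(log_weights) - logsumexp(log_weights)
    return logsumexp(ll_predict + log_weights)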
@jessegrabowski @mathDR @HJA24
I embedded my functions into the notebook. Note, I ran into nan issues for the r_eff parameter, but do not have the time to go into details about why that is. So, for now, I set that variable to 1.0.
This is the final plot, which is pretty close in terms of ELPD between the two methods, but still quite far off the exact khats. I haven't checked this against the R implementation. Nonetheless, this model only refits once, which seems to match the R implementation?
Here's my gist: https://gist.github.com/cmgoold/125eee0952c4905f3318de7ab2a11826
I should also point out that Bürkner et al.'s implementation normalises the exponentiated weights, as I do too. This doesn't look like it's accounted for in your original script.
I'm happy to keep digging away at this and would be interested in helping out putting a PR together too, if we go down that road.
So I would guess that when this work is completed, methods for the various refitting wrappers would be implemented?
Hey @cmgoold, any PR I could review?