dbarts

Convergence diagnostics?

Open timdisher opened this issue 6 years ago • 5 comments

Is there a standard way to assess convergence of the sampler? I have tried to find chains of estimates for the trees to feed into something like posterior, but am not having much luck.

timdisher Mar 10 '20 16:03

There isn't an easy way to assess the convergence of the trees themselves, but you can run any standard diagnostic on any quantity of interest. In a causal setting we often compare the convergence of average treatment effects across chains. For continuous responses, there is also the nuisance parameter of the residual standard deviation.

In all cases, you can keep the chain information by calling bart with combinechains = FALSE (or bart2 with default settings). If you call extract on a fitted model, you can recover that information if it was previously discarded.

What tends to not converge very well are the posteriors of individual predictions. That's led us to increase the number of chains, and also in the future might lead to down-weighting certain chains or sampling across them during warmup.
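As a rough illustration of running a standard diagnostic on a scalar quantity of interest, here is a minimal sketch of a basic split R-hat computed from per-chain draws. It is written in plain Python rather than R, and it uses simulated numbers as a stand-in for per-chain draws of the residual standard deviation; the names `split_rhat`, `sigma_draws`, and `stuck` are hypothetical and not part of dbarts. It is also the plain split R-hat, not the rank-normalized variant that newer diagnostic tools use.

```python
import random
from statistics import fmean, variance


def split_rhat(chains):
    """Basic split R-hat for one scalar quantity.

    `chains` is a list of equal-length lists, one list of draws per chain.
    """
    half = len(chains[0]) // 2
    # Split every chain in half so that drift within a chain also counts
    # as disagreement between sequences.
    parts = [chain[start:start + half] for chain in chains for start in (0, half)]
    within = fmean([variance(p) for p in parts])          # W: mean within-sequence variance
    between = half * variance([fmean(p) for p in parts])  # B: scaled variance of sequence means
    var_plus = (half - 1) / half * within + between / half
    return (var_plus / within) ** 0.5


random.seed(0)
# Simulated stand-in (an assumption, not real sampler output):
# 4 chains x 500 draws of the residual standard deviation.
sigma_draws = [[random.gauss(1.0, 0.1) for _ in range(500)] for _ in range(4)]
rhat_mixed = split_rhat(sigma_draws)

# Same draws, but with one chain shifted to mimic a chain that settled elsewhere.
stuck = [list(chain) for chain in sigma_draws]
stuck[0] = [x + 0.2 for x in stuck[0]]
rhat_stuck = split_rhat(stuck)

print(f"well mixed: {rhat_mixed:.3f}, one chain shifted: {rhat_stuck:.3f}")
```

Values close to 1 suggest the chains agree on the quantity; the shifted example should come out noticeably larger. The same function could be applied to per-chain draws of an average treatment effect.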

vdorie Mar 10 '20 22:03

Thank you, this is very helpful. Do I understand correctly that, for a straight predictive model, you're suggesting it makes sense to just assess the convergence of the predictions (yhats)? This is the solution I've landed on in the interim!

timdisher Mar 10 '20 22:03

More or less, but don't be too surprised if some chains look a bit weird when you look at a maximum of n R.hats. It would also be possible to target the log-posterior, or any average across a unit of analysis (for example groups). I've thought a bit about how to get better mixing but haven't had time to implement anything yet.
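To make the "maximum of n R.hats" and the group-average idea concrete, here is a small sketch under stated assumptions. It is plain Python on simulated data, not dbarts output: `yhat` is a hypothetical nested list laid out as chains by draws by observations, `groups` is a made-up mapping from group label to observation indices, and `split_rhat` is a basic (not rank-normalized) split R-hat.

```python
import random
from statistics import fmean, variance


def split_rhat(chains):
    """Basic split R-hat; `chains` is a list of equal-length lists of draws."""
    half = len(chains[0]) // 2
    parts = [chain[start:start + half] for chain in chains for start in (0, half)]
    within = fmean([variance(p) for p in parts])
    between = half * variance([fmean(p) for p in parts])
    var_plus = (half - 1) / half * within + between / half
    return (var_plus / within) ** 0.5


def column(yhat, i):
    """Per-chain draws for observation i from a chains x draws x n nested list."""
    return [[draw[i] for draw in chain] for chain in yhat]


def group_mean(yhat, idx):
    """Per-chain draws of the average prediction over the observations in idx."""
    return [[fmean([draw[i] for i in idx]) for draw in chain] for chain in yhat]


random.seed(1)
n_chains, n_draws, n_obs = 4, 200, 6
# Simulated predictions (an assumption): observation i is centered at i.
yhat = [[[random.gauss(i, 1.0) for i in range(n_obs)]
         for _ in range(n_draws)] for _ in range(n_chains)]
# Make chain 0 disagree with the others about observation 2.
for draw in yhat[0]:
    draw[2] += 3.0

# One R-hat per observation, then the worst case.
rhats = [split_rhat(column(yhat, i)) for i in range(n_obs)]
worst = max(range(n_obs), key=rhats.__getitem__)
max_rhat = rhats[worst]

# R-hat of an average over a unit of analysis (hypothetical groups).
groups = {"a": [0, 1, 2], "b": [3, 4, 5]}
group_rhats = {g: split_rhat(group_mean(yhat, idx)) for g, idx in groups.items()}

print(f"worst observation: {worst}, max R-hat: {max_rhat:.3f}")
print({g: round(r, 3) for g, r in group_rhats.items()})
```

Because the maximum is taken over many observations, a single poorly mixing prediction dominates it, which is one reason averages over groups can be easier to interpret.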

vdorie Mar 11 '20 19:03

Can anyone suggest how this sort of check might be done, even if it were only on the yhats? I've been looking at, for example, using the rstan convergence tools (Rhat, ess_bulk, etc.). If I have a yhat.train matrix of uncombined chains that is 4 chains by 500 draws by 1,000 yhats, how would that transpose or translate into the sort of matrix these tools are looking for? In its natural form we have an array that would seem like it has to be simplified in some manner. Thanks!

bachlaw Sep 10 '21 16:09

With that said, I suspect one can infer probable stationarity of the posterior by doing a train/test split of the data and scoring various combinations of burn-in and saved samples until you reach maximum out-of-sample accuracy and stop seeing further improvement from more of either.

bachlaw Sep 10 '21 17:09