Bayesian IV slow for large data sets

NathanielF opened this issue 2 years ago

Just adding a note here to investigate the parameterisation options of the Bayesian IV model. I ran into this while fitting large IV designs with lots of data and many instruments for this ticket: https://github.com/pymc-labs/CausalPy/issues/229

The Bayesian IV design just takes a very long time and doesn't always give good results on large data sets. Because CausalPy hides some of the model-building complexity, it's harder for the user to iteratively debug and re-parameterise. I'm wondering if there is anything in the model specification we can do to address this.
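
For context, the parameterisation in question is roughly a joint-likelihood model where the first-stage (treatment) and second-stage (outcome) equations share an MvNormal likelihood. A minimal sketch, not CausalPy's actual implementation (function and variable names are illustrative):

```python
# Minimal sketch of a joint-likelihood Bayesian IV model (illustrative,
# not CausalPy's actual code). Z: instruments, X: exogenous covariates,
# t: treatment, y: outcome.
import numpy as np
import pymc as pm
import pytensor.tensor as pt

def make_iv_model(y, t, Z, X):
    with pm.Model() as model:
        beta_t = pm.Normal("beta_t", 0, 1, shape=Z.shape[1])  # first stage
        beta_y = pm.Normal("beta_y", 0, 1, shape=X.shape[1])  # second stage
        tau = pm.Normal("tau", 0, 1)  # treatment effect of interest
        # LKJ prior ties the two equations together via correlated errors
        chol, _, _ = pm.LKJCholeskyCov(
            "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0), compute_corr=True
        )
        mu = pt.stack([pt.dot(Z, beta_t), tau * t + pt.dot(X, beta_y)], axis=1)
        pm.MvNormal("likelihood", mu=mu, chol=chol, observed=np.column_stack([t, y]))
    return model
```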

NathanielF avatar Sep 01 '23 08:09 NathanielF

FYI @juanitorduz

I think I'm going to read up on this paper: https://www.nber.org/papers/t0204

It seems like a good example of the advantage of a Bayesian solution, but we need to ensure some kind of efficiency.

NathanielF avatar Sep 01 '23 08:09 NathanielF

Just to add that the IV methods slow down the tests and doctests. So a great solution to this issue would also address that :)

drbenvincent avatar Sep 18 '23 18:09 drbenvincent

Have you tried benchmarking the model with the JAX backend? I peeked at the code and saw it uses an MvNormal likelihood. I see orders-of-magnitude speedups on statespace models (also MvNormal) by switching to the JAX sampler.
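
For reference, switching samplers is a one-liner in recent PyMC (assuming PyMC >= 5 with `numpyro` and `jax` installed); a minimal benchmark sketch on a toy model:

```python
import numpy as np
import pymc as pm

with pm.Model():
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("obs", mu=mu, sigma=1, observed=np.random.normal(size=100))
    idata_default = pm.sample()                        # default C-backend NUTS
    idata_numpyro = pm.sample(nuts_sampler="numpyro")  # JAX-based NUTS
```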

jessegrabowski avatar Sep 23 '23 13:09 jessegrabowski

Just wondering if any lessons were learnt from #345, @NathanielF, that could be useful for this issue? If the default backend is particularly slow, then could this be worth an issue in the PyMC repo?

Or maybe the speed has improved given changes to the IV code?

drbenvincent avatar Jun 19 '24 19:06 drbenvincent

The issue for this is https://github.com/pymc-devs/pymc/issues/7348

But the specific error @NathanielF reported in his testing is that there's no JAX funcify for LKJCholeskyCov, which needs another issue. TFP has it, so it should be quite trivial to implement (famous last words)
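
For what it's worth, the TFP JAX substrate does expose a sampler for this; a minimal check (assuming `tensorflow-probability` with JAX support is installed):

```python
import jax
from tensorflow_probability.substrates import jax as tfp

# Sample a Cholesky factor of an LKJ-distributed correlation matrix
dist = tfp.distributions.CholeskyLKJ(dimension=3, concentration=2.0)
chol = dist.sample(seed=jax.random.PRNGKey(0))  # (3, 3) lower-triangular factor
```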

jessegrabowski avatar Jun 19 '24 19:06 jessegrabowski

The main finding was that the actual model fit can be quite quick with both the default sampler and the numpyro sampler. Honestly there's not a huge difference: `pm.sample` takes roughly 5 minutes for about 3000 rows of data, with priors informed by the 2SLS pre-processing step.

Good prior management is important in IV regression, especially where the parameters can differ in scale by orders of magnitude. I'd generally recommend standardising inputs, but I didn't here because I wanted to replicate the Card parameter recovery.
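
For illustration, a minimal sketch of both points: standardising inputs and centring priors on a 2SLS fit (assuming the `linearmodels` package; the file and column names are illustrative, loosely following the Card schooling data):

```python
import pandas as pd
from linearmodels.iv import IV2SLS

df = pd.read_csv("card.csv")  # illustrative data set

# Standardise continuous inputs so parameters live on a comparable scale
for col in ["lwage", "educ", "exper"]:
    df[col] = (df[col] - df[col].mean()) / df[col].std()

# Use 2SLS point estimates as weakly-informative prior locations
res = IV2SLS(
    dependent=df["lwage"],
    exog=df[["exper"]].assign(const=1.0),
    endog=df["educ"],
    instruments=df["nearc4"],
).fit()
prior_mu = res.params             # centre priors here
prior_sigma = 2 * res.std_errors  # inflate to stay weakly informative
```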

So broadly I think IV model fits can be achieved in reasonable time with the base model.

The main issue before was that we had bundled all the posterior predictive sampling into the model fit and hidden the progress bars, so it took me a minute to realise that the majority of the time was spent in posterior predictive sampling. This was sped up dramatically (20 minutes -> 2 seconds) with @jessegrabowski's JAX trick. So I think the issue is perhaps just some inefficiency in posterior predictive sampling of multivariate normal distributions in the base PyMC implementation...

The same slowness occurs with the prior predictive checks, but it's less pronounced because we only draw 500 samples by default.
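
For reference, the trick is just to compile the predictive function with the JAX backend; a minimal sketch (assuming PyMC >= 5, where the predictive samplers accept `compile_kwargs`):

```python
import pymc as pm

# `model` and `idata` are the fitted IV model and its pm.sample() output
with model:
    # Default backend: can be very slow for MvNormal likelihoods
    # pm.sample_posterior_predictive(idata, extend_inferencedata=True)
    # JAX-compiled predictive sampling: dramatically faster here
    pm.sample_posterior_predictive(
        idata, extend_inferencedata=True, compile_kwargs={"mode": "JAX"}
    )
```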

NathanielF avatar Jun 19 '24 19:06 NathanielF