CausalPy
Bayesian IV slow for large data sets
Just adding a note here to investigate the parameterisation options of the Bayesian IV model. I was exploring this while fitting large IV designs with lots of data and many instruments for this ticket: https://github.com/pymc-labs/CausalPy/issues/229
The Bayesian IV design just takes a very long time and doesn't always give good results on large data sets. Because CausalPy hides some of the model building complexity, it's harder for the user to iteratively debug and re-parameterise. I'm wondering if there is anything in the model specification we can do to address this.
FYI @juanitorduz
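For context, the core of the model is (roughly) a joint MvNormal over the treatment and outcome equations, with the covariance parameterised via LKJCholeskyCov. A minimal sketch, with illustrative names and toy data rather than the exact CausalPy internals:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

# Toy data: instrument z, endogenous treatment t, outcome y (illustrative only)
rng = np.random.default_rng(0)
n = 500
z = rng.normal(size=n)
u = rng.normal(size=n)                        # unobserved confounder
t = 0.8 * z + u + rng.normal(size=n)
y = 0.5 * t + u + rng.normal(size=n)
Z = np.column_stack([np.ones(n), z])          # instrument matrix incl. intercept

with pm.Model() as iv_model:
    beta_t = pm.Normal("beta_t", 0, 1, shape=Z.shape[1])   # first-stage coefficients
    beta_y = pm.Normal("beta_y", 0, 1, shape=2)             # outcome intercept + treatment effect
    mu_t = pt.dot(Z, beta_t)
    mu_y = beta_y[0] + beta_y[1] * t
    # Correlated errors across the two equations via a Cholesky-factored covariance
    chol, corr, sds = pm.LKJCholeskyCov(
        "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0)
    )
    pm.MvNormal(
        "likelihood",
        mu=pt.stack([mu_t, mu_y], axis=1),
        chol=chol,
        observed=np.column_stack([t, y]),
    )
```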
I think I'm going to read up on this paper: https://www.nber.org/papers/t0204
It seems like a good example of the advantage of a Bayesian solution, but we need to ensure some kind of efficiency.
Just to add that the IV methods slow down the tests and doctests. So a great solution to this issue would also address that :)
Have you tried benchmarking the model with the JAX backend? I peeked at the code and saw it was an MvNormal likelihood. I see orders-of-magnitude speedups on statespace models (also MvNormal) by switching to the JAX sampler.
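Something along these lines, assuming `iv_model` is the underlying PyMC model (as in the sketch above) and numpyro is installed:

```python
with iv_model:
    # NUTS via the numpyro/JAX backend instead of the default PyMC sampler
    idata = pm.sample(nuts_sampler="numpyro")
```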
Just wondering if any lessons were learnt from #345, @NathanielF, which could be useful for this issue? If the default backend is particularly slow, then could this be worth an issue in the PyMC repo?
Or maybe the speed has improved given changes to the IV code?
The issue for this is https://github.com/pymc-devs/pymc/issues/7348
But the specific error @NathanielF reported in his testing is that there's no JAX funcify for LKJCholeskyCov, which needs another issue. TFP has it, so it should be quite trivial to implement (famous last words)
The main finding was just that the actual model fit can be quite quick with both the default sampler and the numpyro sampler. Honestly there isn't a huge difference: pm.sample takes ~5 mins for about 3000 rows of data, with priors informed by the 2SLS processing step.
Good prior management is important with IV regression, especially where the parameters can be on quite different scales. I'd probably recommend standardising inputs in general (see the sketch below), but didn't do so here because I wanted to replicate the Card parameter recovery.
So broadly I think IV model fits can be achieved in reasonable time with the base model.
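On the standardising point, I mean something like this (toy arrays and made-up names just for illustration), with the 2SLS estimates then used to centre the priors:

```python
import numpy as np

# Toy arrays standing in for the real outcome, treatment and instrument columns
rng = np.random.default_rng(1)
z = rng.normal(size=3000)
t = 0.8 * z + rng.normal(size=3000)
y = 0.5 * t + rng.normal(size=3000)

def standardise(x):
    return (x - x.mean()) / x.std()

y_s, t_s, z_s = standardise(y), standardise(t), standardise(z)

# 2SLS point estimates on the standardised data, which can then be used
# to centre the Normal priors in the Bayesian model
Z = np.column_stack([np.ones(len(z_s)), z_s])
t_hat = Z @ np.linalg.lstsq(Z, t_s, rcond=None)[0]           # first stage
X2 = np.column_stack([np.ones(len(z_s)), t_hat])
b_2sls = np.linalg.lstsq(X2, y_s, rcond=None)[0]             # [intercept, treatment effect]
```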
The main issue before was that we had bundled all the posterior predictive sampling into the model fit and hid the progress bars, so it took me a minute to realise that the majority of the time was being spent in posterior predictive sampling. That was sped up from ~20 mins to ~2 seconds with @jessegrabowski's JAX trick. So I think the issue is perhaps just some inefficiency in posterior predictive sampling with multivariate normal distributions in the base PyMC instantiation...
The same slowness occurs with the prior predictive checks but is less pronounced because we only draw 500 samples by default.
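For reference, I believe the JAX trick was roughly along these lines (assuming the `iv_model` and `idata` from the fit above); the prior predictive step should benefit from the same approach:

```python
with iv_model:
    # Compile the forward-sampling function with the JAX backend
    idata.extend(
        pm.sample_posterior_predictive(idata, compile_kwargs={"mode": "JAX"})
    )
```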