Vectorize NUTS
It would be nice to have a vectorized/GPU-compatible implementation of NUTS to follow up the excellent work in #117. That PR identifies the main limitation: how to handle different tree depths across chains.
NumPyro and TensorFlow Probability (used by PyMC4) have addressed this problem by developing an iterative NUTS approach. You can see their write-ups at https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS and https://github.com/tensorflow/probability/blob/master/discussion/technical_note_on_unrolled_nuts.md.
While I haven't benchmarked TFP, there's supposedly not much of a runtime difference between running a single chain and hundreds of chains, at least for simple models. See this blog post for examples.
@junpenglao shared with me this related Stan discourse: https://discourse.mc-stan.org/t/parallel-dynamic-hmc-merits/10895/
I plan to implement this in three consecutive PRs, each with its own release:
- Reimplement NUTS iteratively. The iterative implementation should be equivalent to the recursive one, so that passing the same RNG to both versions produces identical draws and stats. It should also be at least as performant and at least as readable as recursive NUTS. If these three criteria are met, it can replace the current recursive NUTS implementation. Performing these checks will be a lot easier once there's an official benchmark for this package.
- Vectorize the iterative NUTS implementation. This is "vectorize" in the sense of "use broadcasting everywhere". It will be patterned off the existing Static HMC vectorization, though I'll also profile a KernelAbstractions-based approach. In the broadcasting case, this will require updating the `BinaryTree` and related representations to describe a tree across all chains, and handling control flow across a vector of directions (left/right).
- Ensure the vectorized NUTS implementation works well on the GPU. This will need to come after @xukai92's fixes for Static HMC sampling on the GPU.
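To make the recursive/iterative equivalence in the first step concrete, here's a toy sketch (in Python, since the scheme comes from NumPyro's write-up; `recursive_checks` and `iterative_checks` are illustrative names, not proposed API). The key observation from the iterative-NUTS note is that a flat loop over leapfrog steps can fire the same post-order U-turn checks as the recursion: after adding leaf `i`, each trailing 1-bit of `i` corresponds to one just-completed subtree.

```python
def recursive_checks(lo, hi, events):
    # Recursive doubling: visit leaves lo..hi-1 in order and record a
    # U-turn check each time a subtree completes (post-order).
    if hi - lo == 1:
        events.append(("leaf", lo))
    else:
        mid = (lo + hi) // 2
        recursive_checks(lo, mid, events)
        recursive_checks(mid, hi, events)
        events.append(("check", lo, hi))
    return events

def iterative_checks(depth):
    # Flat loop over 2**depth leaves, no call stack. After leaf i, one
    # subtree of size 2, 4, ... completes per trailing 1-bit of i, so the
    # same checks fire in the same order as the recursion.
    events = []
    for i in range(2 ** depth):
        events.append(("leaf", i))
        j, size = i, 2
        while j & 1:  # one completed subtree per trailing 1-bit
            events.append(("check", i + 1 - size, i + 1))
            j >>= 1
            size *= 2
    return events

assert recursive_checks(0, 8, []) == iterative_checks(3)
```

Since both versions visit leaves and run checks in the same order, feeding them the same RNG stream should yield identical draws, which is exactly the equivalence criterion above.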
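And as a sketch of the direction-handling problem in the vectorization step (again Python/NumPy purely for illustration; the real implementation would use Julia broadcasting): each chain draws its own ±1 direction, so both endpoints are stepped for every chain and a mask keeps only the endpoint that grew. `toy_step` is a hypothetical stand-in for the leapfrog integrator.

```python
import numpy as np

def toy_step(theta, d):
    # Stand-in for one leapfrog step in direction d (+1 or -1); a real
    # integrator would also carry momenta, log-densities, etc.
    return theta + 0.1 * d

def grow_trajectory(theta_left, theta_right, directions):
    # theta_left, theta_right: (n_chains, n_dims) trajectory endpoints.
    # directions: (n_chains,) with entries +1 (grow right) or -1 (grow left).
    d = directions[:, None]                  # broadcast over dimensions
    grow_right = (directions > 0)[:, None]
    # Step both endpoints for all chains, then mask: each chain keeps the
    # endpoint it grew and leaves the other unchanged.
    theta_right = np.where(grow_right, toy_step(theta_right, d), theta_right)
    theta_left = np.where(grow_right, theta_left, toy_step(theta_left, d))
    return theta_left, theta_right

left = np.zeros((2, 3))
right = np.zeros((2, 3))
# chain 0 grows right, chain 1 grows left
left, right = grow_trajectory(left, right, np.array([1, -1]))
```

The wasted half-step per chain is the usual price of lock-step vectorization; whether that beats a KernelAbstractions kernel with per-chain branching is exactly what the profiling above should settle.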