blackjax
We should be able to distribute gradient computations
In particular, the implementation should be general enough that we can shard a large dataset across several machines, compute the partial gradient on each machine, and combine their values before performing a leapfrog step, as in https://arxiv.org/pdf/2104.14421.pdf
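
A minimal sketch of the idea, not blackjax's API: shard the data across devices along the leading axis, let each device differentiate its shard's contribution to the log-density with `jax.grad` inside a `jax.pmap`, and sum the partial gradients with `jax.lax.psum` before using them in the leapfrog update. The names `shard_logdensity`, `leapfrog`, and the toy Gaussian likelihood are all illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def shard_logdensity(position, data_shard):
    # Per-shard contribution to the full log-density. A toy Gaussian
    # likelihood stands in for the real model here (illustrative only).
    return jnp.sum(-0.5 * (data_shard - position) ** 2)

# Each device gets a replicated copy of `position` and one data shard,
# computes its partial gradient, and psum combines them across devices.
p_grad = jax.pmap(
    lambda position, shard: jax.lax.psum(
        jax.grad(shard_logdensity)(position, shard), axis_name="shards"
    ),
    axis_name="shards",
)

def leapfrog(position, momentum, data_shards, step_size=1e-2):
    n = data_shards.shape[0]  # leading axis = one shard per device

    def full_grad(q):
        # Replicate the position on every device; psum makes the result
        # identical everywhere, so we read it off device 0.
        replicated = jnp.broadcast_to(q, (n,) + q.shape)
        return p_grad(replicated, data_shards)[0]

    # Standard leapfrog: half momentum step, full position step,
    # half momentum step with the gradient at the new position.
    momentum = momentum + 0.5 * step_size * full_grad(position)
    position = position + step_size * momentum
    momentum = momentum + 0.5 * step_size * full_grad(position)
    return position, momentum

# Example usage: data pre-sharded with one shard per local device.
key = jax.random.PRNGKey(0)
data = jax.random.normal(key, (jax.local_device_count(), 1000, 3))
q, p = jnp.zeros(3), jax.random.normal(key, (3,))
q, p = leapfrog(q, p, data)
```

In blackjax the combination would presumably happen inside the gradient callable handed to the integrator, so the leapfrog code itself stays unchanged; the sketch above just makes the shard-then-psum pattern explicit.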
I am happy to pick this up!
Related: https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX