
We should be able to distribute gradient computations

Open rlouf opened this issue 3 years ago • 2 comments

In particular, the implementation should be general enough that we can shard a large dataset across several machines, compute the partial gradient on each machine, and combine the results before taking a leapfrog step, as in https://arxiv.org/pdf/2104.14421.pdf
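A minimal sketch of the idea, using `jax.pmap` with a `psum` all-reduce to combine per-shard gradients before a leapfrog step. This is not BlackJAX API; `shard_loglik`, `full_grad`, and `leapfrog` are illustrative names, and the Gaussian log-likelihood is a toy model standing in for the real log-density:

```python
from functools import partial

import jax
import jax.numpy as jnp

# One data shard per local device.
n_dev = jax.local_device_count()
data = jnp.arange(8.0 * n_dev).reshape(n_dev, -1)

def shard_loglik(theta, shard):
    # Partial log-likelihood contributed by one shard (Gaussian toy model).
    return -0.5 * jnp.sum((shard - theta) ** 2)

@partial(jax.pmap, axis_name="shards", in_axes=(None, 0))
def full_grad(theta, shard):
    # Each device differentiates its own shard, then psum all-reduces the
    # partial gradients into the full-data gradient, replicated on every device.
    g = jax.grad(shard_loglik)(theta, shard)
    return jax.lax.psum(g, axis_name="shards")

def leapfrog(theta, momentum, step_size):
    # One leapfrog step driven by the combined full-data gradient.
    momentum = momentum + 0.5 * step_size * full_grad(theta, data)[0]
    theta = theta + step_size * momentum
    momentum = momentum + 0.5 * step_size * full_grad(theta, data)[0]
    return theta, momentum

theta, p = leapfrog(jnp.array(0.0), jnp.array(0.0), 1e-2)
```

Because the gradient is all-reduced inside the `pmap`, every device holds the same full-data gradient, so the leapfrog update stays identical to the single-machine version.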

rlouf avatar Jun 29 '22 09:06 rlouf

I am happy to pick this up!

ludgerpaehler avatar Jan 13 '23 16:01 ludgerpaehler

Related: https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX

junpenglao avatar Jan 13 '23 16:01 junpenglao