blackjax
Refactor marginal_latent_gaussian.py so it is PyTree compatible
The library has a history of not entirely working with PyTrees; see #216 for instance. We should make sure that it does, and add tests that check it, before the first stable release.
As discussed in #441, this primarily involves refactoring mgrad. We also need some sort of resolution on #77.
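To make the goal concrete, here is a minimal sketch of one way a kernel written for flat arrays could be made to accept PyTree positions, using `jax.flatten_util.ravel_pytree`. Everything named here (`logdensity_fn`, `flat_step`, `pytree_step`) is hypothetical and for illustration only; it does not reflect the actual code in marginal_latent_gaussian.py or the signature of mgrad, and the refactor may well take a different route (for example, operating on leaves with `jax.tree_util.tree_map`).

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def logdensity_fn(position):
    # Hypothetical target: independent standard normal over every leaf.
    leaves = jax.tree_util.tree_leaves(position)
    return -0.5 * sum(jnp.sum(leaf**2) for leaf in leaves)


def flat_step(flat_position, flat_logdensity_fn, step_size):
    # Stand-in for a kernel that only understands 1D arrays:
    # a single gradient ascent step on the log-density.
    grad = jax.grad(flat_logdensity_fn)(flat_position)
    return flat_position + step_size * grad


def pytree_step(position, logdensity_fn, step_size):
    # Flatten the PyTree, run the flat kernel, then restore the structure.
    flat_position, unravel_fn = ravel_pytree(position)

    def flat_logdensity_fn(x):
        return logdensity_fn(unravel_fn(x))

    new_flat = flat_step(flat_position, flat_logdensity_fn, step_size)
    return unravel_fn(new_flat)


# A position that is a dict of arrays with different shapes.
position = {"loc": jnp.ones(3), "scale": jnp.array([[2.0, 0.0], [0.0, 2.0]])}
new_position = pytree_step(position, logdensity_fn, 0.1)
```

A test along these lines would assert that the output has the same tree structure and leaf shapes as the input, and that the values match what the flat kernel produces on the raveled position.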