Rémi Louf

Results: 533 comments by Rémi Louf

I think this algorithm would be a nice addition to Blackjax!

The RTD build is now in place. However, there are several notebooks that fail to execute: Aesara, PyMC and Oryx. We need to fix this before completely switching to RTD....

In #392 we define a `VIAlgorithm`, and here we would need to define a new base type, `ParametrizedVIAlgorithm`.

Great! No, there is no such assumption in the library (or at least there shouldn't be); we try to support PyTree states as much as we can.

As you can see with the [pathfinder implementation](https://github.com/blackjax-devs/blackjax/blob/bab42d809b48492f2cbc06471497cefbbf8a90f8/blackjax/kernels.py#L1204), Blackjax treats VI differently from MCMC algorithms. The idea is that you first fit an approximation to the target density, and then...
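To make the "fit first" idea concrete, here is a minimal sketch in plain Python. It is not Blackjax's API: the function names `fit_gaussian_vi` and `sample_approximation`, the Gaussian target, and the step sizes are all assumptions chosen for illustration. The comment above is cut off after "and then", so the second step shown here (drawing samples from the fitted approximation) is also an assumption about the usual VI workflow, not a quote.

```python
import math
import random


def fit_gaussian_vi(target_mean, target_std, steps=2000, batch=64, lr=0.02, seed=0):
    """Fit q = N(mu, sigma^2) to a 1D Gaussian target by stochastic
    gradient ascent on the ELBO (hypothetical helper, not Blackjax code)."""
    rng = random.Random(seed)
    mu, rho = 0.0, 0.0  # rho = log(sigma), which keeps sigma positive
    for _ in range(steps):
        sigma = math.exp(rho)
        grad_mu, grad_rho = 0.0, 0.0
        for _ in range(batch):
            eps = rng.gauss(0.0, 1.0)
            x = mu + sigma * eps  # reparameterized draw from q
            # Gradient of the target log-density with respect to x.
            dlogp_dx = -(x - target_mean) / target_std**2
            grad_mu += dlogp_dx
            grad_rho += dlogp_dx * sigma * eps
        # The entropy of q contributes +1 to the gradient with respect to rho.
        mu += lr * grad_mu / batch
        rho += lr * (grad_rho / batch + 1.0)
    return mu, math.exp(rho)


def sample_approximation(mu, sigma, n, seed=1):
    """Draw n samples from the fitted approximation (assumed second step)."""
    rng = random.Random(seed)
    return [rng.gauss(mu, sigma) for _ in range(n)]


# Step 1: fit the approximation. Step 2: use it without touching the target.
mu, sigma = fit_gaussian_vi(target_mean=3.0, target_std=0.5)
draws = sample_approximation(mu, sigma, n=1000)
```

Unlike an MCMC kernel, which is applied repeatedly to move a chain, the fitting loop here runs once and the result is then reused as a cheap sampler. That separation is the difference in treatment the comment points to.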

MFVI is implemented [here](https://github.com/blackjax-devs/blackjax/blob/main/blackjax/vi/meanfield_vi.py) and full-rank VI is being implemented in https://github.com/blackjax-devs/blackjax/pull/479. The refactoring of Pathfinder is a bit involved, but up for grabs :)

Sorry for my late reaction. I think that the issue that you're facing with the covariance matrix is part of a more general discussion we're having at the library level...

Line 1049 in this file. If you want to see what kind of changes you would need to make you can check out #1414.

Sounds about right. You can look at this PR to get an idea of all the places where you will have to make changes.