AdvancedVI.jl

Add the forward-backward Wasserstein Gaussian variational inference algorithm

Open · Red-Portal opened this issue 2 months ago · 1 comment

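This adds the forward-backward Wasserstein Gaussian variational inference algorithm of Diao et al.[^DBCS2023]. The algorithm minimizes the KL divergence KL(q‖π) by running proximal stochastic gradient descent in the Bures-Wasserstein space, i.e., the space of Gaussian distributions equipped with the Wasserstein-2 distance, where the gradient is the corresponding Bures-Wasserstein gradient. Each iteration splits the objective in two: a forward (stochastic gradient) step on the potential-energy term, followed by a backward (proximal, or JKO) step on the negative-entropy term, which has a closed form for Gaussians. Because it operates directly on the space of measures rather than on a parameterization of them, it tends to converge faster than BBVI/ADVI, provided the step size is well tuned.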
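To make the forward-backward structure concrete, here is a minimal one-dimensional sketch in plain Python. It is not AdvancedVI's (Julia) API: the function `fb_gvi_1d`, its arguments, and the `n_samples` minibatch option are hypothetical and chosen only for illustration. It assumes a target π(x) ∝ exp(-V(x)) with user-supplied V' and V'', estimates the expectations by Monte Carlo, and then applies the two half-steps: a gradient step on the mean with the variance contracted by (1 - h·E[V''])², followed by the closed-form entropy proximal step var ← ½(var + 2h + √(var·(var + 4h))).

```python
import math
import random


def fb_gvi_1d(grad_V, hess_V, m0, var0, step, n_iters, n_samples=1, seed=0):
    """Sketch of forward-backward Gaussian VI for a 1-D target pi(x) ∝ exp(-V(x)).

    Hypothetical helper, not part of AdvancedVI.jl.
    grad_V, hess_V: callables returning V'(x) and V''(x).
    Returns the mean and variance of the final Gaussian approximation N(m, var).
    """
    rng = random.Random(seed)
    m, var = m0, var0
    for _ in range(n_iters):
        # Monte Carlo estimates of E[V'(X)] and E[V''(X)] with X ~ N(m, var).
        sd = math.sqrt(var)
        xs = [m + sd * rng.gauss(0.0, 1.0) for _ in range(n_samples)]
        g = sum(grad_V(x) for x in xs) / n_samples
        H = sum(hess_V(x) for x in xs) / n_samples

        # Forward step: Bures-Wasserstein gradient step on the potential energy E[V].
        m = m - step * g
        M = 1.0 - step * H
        var_half = M * var * M

        # Backward step: closed-form JKO (proximal) step on the negative entropy.
        # The mean is unchanged; only the variance is updated.
        var = 0.5 * (var_half + 2.0 * step
                     + math.sqrt(var_half * (var_half + 4.0 * step)))
    return m, var


# Usage: Gaussian target N(3, 4), so V(x) = (x - 3)^2 / 8, V'(x) = (x - 3) / 4, V''(x) = 1/4.
mean, var = fb_gvi_1d(lambda x: (x - 3.0) / 4.0, lambda x: 0.25,
                      m0=0.0, var0=1.0, step=0.05, n_iters=2000, n_samples=10)
```

For a Gaussian target the exact update has the target itself as a fixed point (for step sizes below the target variance), so here `var` settles at 4 while `mean` fluctuates around 3 because of the Monte Carlo noise in the gradient estimate.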

The v0.5 update is what made it possible to add this algorithm to AdvancedVI. Following this, I plan to add two or three more VI algorithms for v0.6.

[^DBCS2023]: Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In International Conference on Machine Learning. PMLR.

Red-Portal, Oct 25 '25 14:10

AdvancedVI.jl documentation for PR #210 is available at: https://TuringLang.github.io/AdvancedVI.jl/previews/PR210/

github-actions[bot], Oct 26 '25 04:10

Pinging @yebai for good measure. Hong, are you happy for me to go ahead with this?

Red-Portal, Oct 30 '25 15:10