
Adding some basic VI approximations and fitting routines

Status: Open. Opened by junpenglao · 9 comments

Copying over from https://github.com/blackjax-devs/blackjax/pull/392#discussion_r1020745315

After #392, we should add the two most basic VI algorithms: mean-field and full-rank ADVI [1]. Below is a working example of mean-field ADVI:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree  # not imported by `import jax` alone
from jax.scipy import stats

def gen_meanfield_logprob(params):
    """Return a function computing log q(position) for ONE unbatched PyTree position."""
    mu_param, rho_param = params
    sigma_param = jax.tree_util.tree_map(jnp.exp, rho_param)  # rho = log(sigma)
    def meanfield_logprob(position):
        logq_pytree = jax.tree_util.tree_map(
            stats.norm.logpdf, position, mu_param, sigma_param
            )
        logq = jax.tree_util.tree_map(jnp.sum, logq_pytree)
        return jax.tree_util.tree_reduce(jnp.add, logq)
    return meanfield_logprob

# gen_meanfield_logprob(init_params)(init_position)

def meanfield_sample(
    rng_key, meanfield_param, num_samples: int
    ):
    """Reparametrized draws mu + sigma * eps; adds a leading axis unless num_samples == ()."""
    if not isinstance(num_samples, tuple):
        num_samples = (num_samples,)
    mu_param, rho_param = meanfield_param
    sigma_param = jax.tree_util.tree_map(jnp.exp, rho_param)
    mu_flatten, unravel_fn = ravel_pytree(mu_param)
    sigma_flatten, _ = ravel_pytree(sigma_param)
    flatten_sample = jax.random.normal(
        rng_key, num_samples + mu_flatten.shape
        ) * sigma_flatten + mu_flatten
    if len(num_samples) == 0:
        return unravel_fn(flatten_sample)
    return jax.vmap(unravel_fn)(flatten_sample)

# meanfield_sample(rng, init_params, ())

def meanfield_approximate(rng, init_params, log_prob_fn, optimizer, sample_size=5, num_steps=200):
    def meanfield_approximate_step(
        state, rng_key_sample
        ):
        params, opt_state = state
        def kl_fn(params):
            sample = meanfield_sample(rng_key_sample, params, sample_size)
            # Evaluate log q and log p per draw (vmap over the leading sample axis);
            # summing log q over all axes would count it sample_size times.
            logq = jax.vmap(gen_meanfield_logprob(params))(sample)
            logp = jax.vmap(log_prob_fn)(sample)
            return (logq - logp).mean()
        # Monte Carlo estimate of KL(q || p) up to log Z, i.e. the negative ELBO
        neg_elbo, grad = jax.value_and_grad(kl_fn)(params)
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = jax.tree_util.tree_map(
            lambda p, u: p + u, params, updates
            )
        return (params, opt_state), neg_elbo

    def run_optimization(init_params):
        opt_state = optimizer.init(init_params)
        state = (init_params, opt_state)
        rng_key = jax.random.split(rng, num_steps)
        return jax.lax.scan(
            meanfield_approximate_step, state, rng_key
            )

    return run_optimization(init_params)
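For reference, `kl_fn` is meant to return a reparametrised Monte Carlo estimate of the negative ELBO, i.e. KL(q‖p) shifted by the unknown log normalising constant of the target. Writing $\tilde p$ for the unnormalised density that `log_prob_fn` evaluates, $p = \tilde p / Z$, $S$ for `sample_size`, and $\sigma = e^{\rho}$ elementwise (the notation is ours, not from the thread):

$$
q_{\mu,\rho}(z) = \prod_i \mathcal{N}\!\left(z_i \mid \mu_i, e^{2\rho_i}\right), \qquad
z^{(s)} = \mu + e^{\rho} \odot \epsilon^{(s)}, \quad \epsilon^{(s)} \sim \mathcal{N}(0, I),
$$

$$
\hat{\mathcal{L}}(\mu, \rho) = \frac{1}{S} \sum_{s=1}^{S} \left[ \log q_{\mu,\rho}\big(z^{(s)}\big) - \log \tilde{p}\big(z^{(s)}\big) \right]
\approx \mathrm{KL}\big(q_{\mu,\rho} \,\|\, p\big) - \log Z = -\mathrm{ELBO}(\mu, \rho).
$$

Each summand pairs $\log q$ and $\log \tilde p$ at the same single draw $z^{(s)}$, so both have to be evaluated per sample before averaging.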

Fitting a model looks like:

import matplotlib.pyplot as plt
import numpy as np

import optax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

rng = jax.random.PRNGKey(0)

# Separate keys for data, the (shape-only) initial position, and the fit
seed0, seed1, init_seed, fit_seed = jax.random.split(rng, 4)
X = jax.random.normal(seed0, (100, 98))
y = X @ np.arange(98) + jax.random.normal(seed1, (100,))

@tfd.JointDistributionCoroutineAutoBatched
def model():
    sigma = yield tfd.HalfNormal(5.0, name='sigma')
    mu = yield tfd.Normal(0.0, 1.0, name='mu')
    beta = yield tfd.Sample(tfd.Normal(mu, sigma), X.shape[-1], name='beta')
    yield tfd.Normal(X @ beta, 1.0, name="y")

# init_position = model.sample(seed=rng)
pinned = model.experimental_pin(y=y)
# Only used for its PyTree structure and leaf shapes
init_position = pinned.sample_unpinned(seed=init_seed)

# Map unconstrained parameters onto the model's support (e.g. sigma > 0)
bijectors = pinned.experimental_default_event_space_bijector()
def log_prob_fn(unbound_param):
    param = bijectors.forward(unbound_param)
    log_det_jacobian = bijectors.forward_log_det_jacobian(unbound_param)
    return pinned.unnormalized_log_prob(param) + log_det_jacobian

# This is just one way to do it. We could also use a flattened array to represent mu and rho
mu_param = jax.tree_util.tree_map(jnp.zeros_like, init_position)
rho_param = jax.tree_util.tree_map(jnp.zeros_like, init_position)
init_params = (mu_param, rho_param)

optimizer = optax.chain(optax.clip(10.), optax.adam(1.))
(params, opt_state), neg_elbo_trace = meanfield_approximate(
    fit_seed, init_params, log_prob_fn, optimizer
    )
plt.plot(neg_elbo_trace)  # negative ELBO per optimization step
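The issue also asks for full-rank ADVI. A minimal sketch of the two pieces it would need on a flat parameter vector, a reparametrised sampler and a log-density, is below. The Cholesky parametrisation (strict lower triangle plus an exponentiated diagonal) and the function names are illustrative assumptions, not BlackJAX's API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def fullrank_chol(params):
    """Lower-triangular Cholesky factor L with a positive (exponentiated) diagonal."""
    _, chol_raw = params
    strict_lower = jnp.tril(chol_raw, k=-1)
    return strict_lower + jnp.diag(jnp.exp(jnp.diag(chol_raw)))


def fullrank_sample(rng_key, params, num_samples):
    """Reparametrised draws z = mu + L @ eps with eps ~ N(0, I), one per row."""
    mu, _ = params
    chol = fullrank_chol(params)
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + eps @ chol.T


def fullrank_logprob(params, z):
    """log N(z | mu, L L^T) for a single flat vector z."""
    mu, _ = params
    chol = fullrank_chol(params)
    white = solve_triangular(chol, z - mu, lower=True)  # L^{-1} (z - mu)
    dim = mu.shape[0]
    return (
        -0.5 * jnp.sum(white**2)
        - jnp.sum(jnp.log(jnp.diag(chol)))
        - 0.5 * dim * jnp.log(2 * jnp.pi)
    )
```

Initialising `chol_raw` to zeros starts from an identity covariance; vmapping `fullrank_logprob` over the draws plugs it into the same negative-ELBO loss as the mean-field version.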

[1] https://arxiv.org/abs/1603.00788

junpenglao, Nov 12 '22

In #392 we define a VIAlgorithm base type; here we would need to define a new ParametrizedVIAlgorithm base type.

rlouf, Nov 12 '22

@rlouf I think I can take charge of this, but just to make sure:

We assume that log_prob_fn (i.e. log p(x, z)) in BlackJax takes a flattened real-valued array (rather than a dict, or something on a constrained space), right?

xidulu, Dec 14 '22

Great! No, there is no such assumption in the library (or at least there shouldn't be); we try to support PyTree states as much as we can.

rlouf, Dec 14 '22
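To illustrate that reply: a mean-field family can be defined leaf by leaf directly on a PyTree position, with no flattening to a single array. This is a sketch, not BlackJAX code; the dict keys and the helpers `meanfield_logq` and `meanfield_draw` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# A PyTree-valued position: a scalar and a vector, kept as a dict.
position = {"mu": jnp.zeros(()), "beta": jnp.zeros(3)}

# Mean-field parameters mirror the structure of the position.
mu_param = jax.tree_util.tree_map(jnp.zeros_like, position)
rho_param = jax.tree_util.tree_map(jnp.zeros_like, position)  # log std devs
params = (mu_param, rho_param)


def meanfield_logq(params, z):
    """log q(z) for ONE PyTree-valued draw z: sum of per-leaf normal log-densities."""
    mu, rho = params
    per_leaf = jax.tree_util.tree_map(
        lambda x, m, r: jnp.sum(stats.norm.logpdf(x, m, jnp.exp(r))), z, mu, rho
    )
    return sum(jax.tree_util.tree_leaves(per_leaf))


def meanfield_draw(rng_key, params):
    """One reparametrised draw with the same PyTree structure as mu."""
    mu, rho = params
    leaves, treedef = jax.tree_util.tree_flatten(mu)
    keys = jax.tree_util.tree_unflatten(
        treedef, list(jax.random.split(rng_key, len(leaves)))
    )
    return jax.tree_util.tree_map(
        lambda k, m, r: m + jnp.exp(r) * jax.random.normal(k, m.shape), keys, mu, rho
    )


# Usage: vmap over draws; each leaf gains a leading sample axis.
draw_keys = jax.random.split(jax.random.PRNGKey(0), 5)
draws = jax.vmap(meanfield_draw, in_axes=(0, None))(draw_keys, params)
logq = jax.vmap(meanfield_logq, in_axes=(None, 0))(params, draws)
```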

@rlouf

To follow the design principles of blackjax, shouldn't VI also have an API of the form below?

new_state, info = kernel(rng_key, state)

where each call performs one optimization step on the ELBO.

xidulu, Dec 26 '22
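A sketch of what such a one-step kernel could look like for mean-field VI, assuming a flat-array position and plain SGD standing in for an optax optimizer. `MFVIState`, `MFVIInfo` and `build_kernel` are hypothetical names, not part of BlackJAX.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy import stats


class MFVIState(NamedTuple):
    mu: jax.Array   # variational means
    rho: jax.Array  # log standard deviations


class MFVIInfo(NamedTuple):
    neg_elbo: jax.Array  # estimate at the pre-update parameters


def build_kernel(logdensity_fn, learning_rate=0.05, num_samples=10):
    """Return kernel(rng_key, state) -> (new_state, info): one SGD step on -ELBO."""

    def neg_elbo(params, rng_key):
        mu, rho = params
        sigma = jnp.exp(rho)
        z = mu + sigma * jax.random.normal(rng_key, (num_samples,) + mu.shape)
        logq = jnp.sum(stats.norm.logpdf(z, mu, sigma), axis=-1)  # per draw
        logp = jax.vmap(logdensity_fn)(z)
        return jnp.mean(logq - logp)

    def kernel(rng_key, state):
        loss, (g_mu, g_rho) = jax.value_and_grad(neg_elbo)((state.mu, state.rho), rng_key)
        new_state = MFVIState(state.mu - learning_rate * g_mu,
                              state.rho - learning_rate * g_rho)
        return new_state, MFVIInfo(loss)

    return kernel


# Usage: target is a product of normals with means (1, -2) and std 0.5.
target_mean = jnp.array([1.0, -2.0])


def logdensity_fn(z):
    return jnp.sum(stats.norm.logpdf(z, target_mean, 0.5))


kernel = jax.jit(build_kernel(logdensity_fn))
state = MFVIState(mu=jnp.zeros(2), rho=jnp.zeros(2))
for step_key in jax.random.split(jax.random.PRNGKey(0), 500):
    state, info = kernel(step_key, state)
```

An optax optimizer would fit the same signature by carrying its state inside the kernel state.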

As you can see in the Pathfinder implementation, Blackjax treats VI differently from MCMC algorithms.

The idea is that you first fit an approximation to the target density, and then sample from this approximation with something like (in pseudo-code):

approx, info = approximate(rng_key, position)
samples = sample(sample_key, approx, num_samples)

I think the higher-level API will always look more or less like this. We can consider a lower-level, kernel-like interface for some algorithms if it makes sense. But again, I am no VI expert and I'm open to suggestions.

rlouf, Dec 26 '22
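A minimal sketch of that two-stage pattern for mean-field VI: `approximate` fits the variational parameters once, and `sample` then draws from the fit without touching the target again. Plain SGD stands in for an optax optimizer, and both functions are hypothetical stand-ins, not BlackJAX's actual signatures.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def approximate(rng_key, logdensity_fn, init_mu, num_steps=1000,
                learning_rate=0.05, num_samples=10):
    """Fit a mean-field Gaussian; return ((mu, rho), per-step negative ELBO)."""

    def neg_elbo(params, key):
        mu, rho = params
        sigma = jnp.exp(rho)
        z = mu + sigma * jax.random.normal(key, (num_samples,) + mu.shape)
        logq = jnp.sum(stats.norm.logpdf(z, mu, sigma), axis=-1)
        return jnp.mean(logq - jax.vmap(logdensity_fn)(z))

    def one_step(params, key):
        loss, grads = jax.value_and_grad(neg_elbo)(params, key)
        params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return params, loss

    init_params = (init_mu, jnp.zeros_like(init_mu))  # rho = 0 means sigma = 1
    return jax.lax.scan(one_step, init_params, jax.random.split(rng_key, num_steps))


def sample(rng_key, approx, num_samples):
    """Draw from the fitted approximation; logdensity_fn is no longer needed."""
    mu, rho = approx
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + jnp.exp(rho) * eps


# Usage on a standard-normal target shifted to mean (3, -1).
target_mean = jnp.array([3.0, -1.0])


def logdensity_fn(z):
    return jnp.sum(stats.norm.logpdf(z, target_mean, 1.0))


approx_key, sample_key = jax.random.split(jax.random.PRNGKey(1))
approx, losses = approximate(approx_key, logdensity_fn, jnp.zeros(2))
samples = sample(sample_key, approx, 2000)
```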

Could someone then give a minimal working example for mean-field VI? I believe this would also help with refactoring the Pathfinder API in #465 and with implementing the full-rank approach.

LarsKarbach, Jan 29 '23

MFVI is implemented here, and full-rank VI is being implemented in https://github.com/blackjax-devs/blackjax/pull/479. Refactoring Pathfinder is a bit involved, but it's up for grabs :)

rlouf, Jan 29 '23

I understand. Although I would argue that it would help anyone who wants to contribute to VI development to have a way to get a foot in the door. It could be as simple as taking a multivariate normal target and evaluating mean-field and full-rank VI on it. With the current implementation, I don't immediately see how the pseudo-code you provided maps onto the library.

LarsKarbach, Jan 30 '23
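For the multivariate-normal test case suggested above, the reverse-KL optimum is available in closed form (a standard result), which gives a concrete target to check both families against. For $p = \mathcal{N}(m, \Sigma)$ with precision $\Lambda = \Sigma^{-1}$ and a mean-field $q(z) = \prod_i \mathcal{N}(z_i \mid \mu_i, \sigma_i^2)$:

$$
\mathrm{KL}(q \,\|\, p) = -\sum_i \log \sigma_i + \tfrac{1}{2}(\mu - m)^\top \Lambda (\mu - m) + \tfrac{1}{2}\sum_i \Lambda_{ii}\,\sigma_i^2 + \text{const},
$$

$$
\Longrightarrow \quad \mu^\star = m, \qquad \sigma_i^{\star 2} = \frac{1}{\Lambda_{ii}} \le \Sigma_{ii}.
$$

The full-rank family contains $p$ itself, so its optimum is $q^\star = p$ with KL $= 0$. With $\Sigma = \begin{pmatrix} 1 & r \\ r & 1 \end{pmatrix}$ we get $\Lambda_{11} = 1/(1 - r^2)$, so for $r = 0.9$ a correct mean-field fit should converge to marginal variances of $0.19$, whereas a full-rank fit should recover the true value $1$.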

@LarsKarbach I understand your point, and I really wish there were a template for implementing VI variants (e.g. as simple as providing a log_q function and a sampling function), but the APIs are still at a very early stage. At the moment there is still a lot of boilerplate code in the implementation... Once the full-rank VI PR is merged, we could probably start working on simplifying the process of implementing VI.

xidulu, Jan 30 '23
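One possible shape for the template described above: a variational family is only a sampler plus a log-density, and the negative-ELBO objective is written once for any family. `VariationalFamily` and `make_neg_elbo` are hypothetical names for this sketch, not BlackJAX's API; the mean-field family is included as the first instance.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy import stats


class VariationalFamily(NamedTuple):
    """Hypothetical template: a family is a sampler and a log-density."""
    sample_fn: Callable  # (rng_key, params, num_samples) -> draws with a leading axis
    logq_fn: Callable    # (params, z) -> scalar log q(z) for a single draw


def make_neg_elbo(family, logdensity_fn, num_samples):
    """Reparametrised Monte Carlo negative ELBO, shared by every family."""
    def neg_elbo(params, rng_key):
        z = family.sample_fn(rng_key, params, num_samples)
        logq = jax.vmap(family.logq_fn, in_axes=(None, 0))(params, z)
        logp = jax.vmap(logdensity_fn)(z)
        return jnp.mean(logq - logp)
    return neg_elbo


def _mf_sample(rng_key, params, num_samples):
    mu, rho = params
    return mu + jnp.exp(rho) * jax.random.normal(rng_key, (num_samples,) + mu.shape)


def _mf_logq(params, z):
    mu, rho = params
    return jnp.sum(stats.norm.logpdf(z, mu, jnp.exp(rho)))


meanfield = VariationalFamily(_mf_sample, _mf_logq)

# Usage/sanity check: if the target equals q, log q(z) - log p(z) = 0 for every draw.
params = (jnp.array([0.5, -1.0]), jnp.log(jnp.array([2.0, 0.3])))


def same_as_q(z):
    return _mf_logq(params, z)


loss = make_neg_elbo(meanfield, same_as_q, num_samples=8)(params, jax.random.PRNGKey(0))
```

A full-rank or other family would then only need its own two functions; the objective, and any optimizer loop built on it, stays shared.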