Add the control variates gradient estimator

We add the control variates gradient estimator for stochastic gradient MCMC algorithms. Control variates require one gradient evaluation on the whole dataset, which raises two questions that may be answered in subsequent PRs:

  1. Shouldn't we compute logposterior_center in a separate init function, and thus propagate a GradientState?
  2. How can we let users distribute this computation as they wish? This operation may need to be distributed with pmap, and we should allow that (a rough sketch follows).
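
A minimal sketch of how a user could distribute that full-dataset gradient themselves with pmap (illustrative only: it assumes loglikelihood_fn(position, batch) returns the summed log-likelihood of batch, and that the dataset is sharded along a leading device axis):

import jax

def full_data_loglikelihood_grad(loglikelihood_fn, position, sharded_data):
    # one partial gradient per device, each over that device's shard
    per_device_grads = jax.pmap(
        lambda shard: jax.grad(loglikelihood_fn)(position, shard)
    )(sharded_data)
    # summing the per-shard gradients gives the full-dataset gradient
    return jax.tree_util.tree_map(lambda g: g.sum(axis=0), per_device_grads)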

This PR is part of an effort to port SGMCMCJAX to blackjax, see #289.

  • [ ] Make sure that the implementation is correct, and add a reference to https://arxiv.org/abs/1706.05439
  • [ ] logposterior_center should be computed in an init function that is passed to the stochastic gradient MCMC kernels. This function is executed in the algorithm's init function, and the value is passed in a GradientState (see the sketch after this list):
    • Doing computation in a build_gradient* function is surprising behavior for users
    • The SVRG gradient estimator, a variant of CV that updates the control value, will need this structure, so it comes at no cost.
  • [ ] Because the gradient estimate is always computed at the end of the diffusion, we end up with one useless gradient computation. We should therefore compute the gradient in the kernels before the diffusion happens.
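
A rough sketch of what the GradientState item above describes (the names are illustrative, not the final API): the centering value is computed once in an init function and then carried in a GradientState.

from typing import Any, NamedTuple

import jax

class GradientState(NamedTuple):
    centering_position: Any  # PyTree
    centering_grad: Any      # full-data gradient of the log-posterior at that position

def init(logposterior_fn, centering_position, data):
    # the single full-dataset gradient evaluation happens here, at init time
    centering_grad = jax.grad(logposterior_fn)(centering_position, data)
    return GradientState(centering_position, centering_grad)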

rlouf avatar Sep 19 '22 10:09 rlouf

Codecov Report

Merging #299 (8684689) into main (becd2d2) will decrease coverage by 0.05%. The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #299      +/-   ##
==========================================
- Coverage   89.79%   89.73%   -0.06%     
==========================================
  Files          45       45              
  Lines        2166     2134      -32     
==========================================
- Hits         1945     1915      -30     
+ Misses        221      219       -2     
Impacted Files                   Coverage Δ
blackjax/kernels.py              99.55% <100.00%> (+0.78%) ↑
blackjax/mcmc/diffusions.py      100.00% <100.00%> (ø)
blackjax/mcmc/mala.py            100.00% <100.00%> (ø)
blackjax/sgmcmc/__init__.py      100.00% <100.00%> (ø)
blackjax/sgmcmc/diffusions.py    100.00% <100.00%> (ø)
blackjax/sgmcmc/gradients.py     100.00% <100.00%> (ø)
blackjax/sgmcmc/sghmc.py         100.00% <100.00%> (ø)
blackjax/sgmcmc/sgld.py          100.00% <100.00%> (ø)


codecov[bot] avatar Sep 19 '22 11:09 codecov[bot]

A few thoughts:

  • I'm on the fence when it comes to keeping CV and implementing SVRG; not only do they require keeping track of a GradientState, they also need the full dataset at initialization (CV) or at every step (SVRG), which can be prohibitive in some scenarios. I would only consider keeping them in an Optax-like API where gradients are computed outside of the integrators. Optax has a control variates API we can take inspiration from.
  • We cannot adopt an Optax-like API, however, because some palindromic integrators (BADODAB for instance) require two gradient evaluations per step. Maybe there's something to learn from Optax's MultiStep interface?
  • Second-order methods like AMAGOLD (#375) may put additional constraints on the API, so we may want to sketch an implementation before moving forward.

All in all, if we didn't have methods where several gradient evaluations are needed to get one sample, our life would be much easier. Then again, maybe none of that matters, since the complexity only affects the internals.
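
To make the tension concrete, a toy illustration (hypothetical signatures, array positions, and a made-up two-evaluation scheme rather than actual BADODAB): an Optax-like update consumes a single precomputed gradient, while a palindromic integrator has to call the estimator itself twice within one step.

def optax_like_update(position, gradient, step_size):
    # the caller computed `gradient` outside; one evaluation per step is implicit
    return position + step_size * gradient

def palindromic_like_update(position, grad_estimator, minibatch, step_size):
    # the integrator needs the estimator, not a gradient: it evaluates it twice
    first_grad = grad_estimator(position, minibatch)
    half_step = position + 0.5 * step_size * first_grad
    second_grad = grad_estimator(half_step, minibatch)
    return half_step + 0.5 * step_size * second_grad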

rlouf avatar Oct 07 '22 14:10 rlouf

High-level question: here you also revert the change in https://github.com/blackjax-devs/blackjax/pull/293/ where you turned the step size into a callable (so users can control how the step size changes during sampling). What is the reasoning behind the revert?

junpenglao avatar Oct 08 '22 06:10 junpenglao

> High-level question: here you also revert the change in https://github.com/blackjax-devs/blackjax/pull/293/ where you turned the step size into a callable (so users can control how the step size changes during sampling). What is the reasoning behind the revert?

That's an important question. The current high-level interface of SgLD is:

sgld = blackjax.sgld(grad_estimator, schedule)

Or for a constant schedule

sgld = blackjax.sgld(grad_estimator, 1e-3)

And then to take a step

sgld.step(rng_key, state, minibatch)

Internally this forces us to increment a counter in the state, which I really dislike. For users, this interface quickly becomes impractical, as I saw when implementing Cyclical SgLD. Having the schedule baked in is also a common criticism of Optax. I much prefer the interface:

sgld = blackjax.sgld(grad_estimator)
...
state = sgld.step(rng_key, state, minibatch, step_size)

My ideal lower-level interface would be:

sgld = blackjax.sgld()
...
step_size = next(schedule)
minibatch = next(dataset)
gradients = grad_estimator(state, minibatch)
state = sgld.step(rng_key, state, gradients, step_size)

But I've expressed why that's difficult above.

rlouf avatar Oct 08 '22 07:10 rlouf

So for low- to mid-level usage, the user might be able to do something like:

import optax

cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)

for i in ...:  # could be in a jax.scan as well
  step_size = cosine_decay_scheduler(i)
  minibatch = ...
  gradients = grad_estimator(state, minibatch)
  state = sgld.step(rng_key, state, gradients, step_size)
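
And the same loop can be written with jax.lax.scan, as hinted in the comment above; a sketch under the same assumed sgld.step and grad_estimator signatures, where init_state, total_steps, and minibatches (pre-stacked into one array) are placeholders:

import jax
import jax.numpy as jnp

def one_step(state, inputs):
    rng_key, minibatch, i = inputs
    step_size = cosine_decay_scheduler(i)
    gradients = grad_estimator(state, minibatch)
    state = sgld.step(rng_key, state, gradients, step_size)
    return state, state

keys = jax.random.split(jax.random.PRNGKey(0), total_steps)
final_state, states = jax.lax.scan(
    one_step, init_state, (keys, minibatches, jnp.arange(total_steps))
)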

junpenglao avatar Oct 08 '22 08:10 junpenglao

@junpenglao @bstaber The behavior in a loop was not very clear indeed, so let me give a full example. I modified (hopefully improved) the behavior slightly. For the simple Robbins-Monro estimator we have:

import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias


schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree

# Get the simple gradient estimator and the SGHMC algorithm
grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)

    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(  # calls the gradient estimator internally
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )
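
For reference, a rough sketch of what such a simple (Robbins-Monro) estimator computes; this is not the blackjax implementation, and it assumes loglikelihood_fn(position, minibatch) returns the summed log-likelihood of the minibatch:

import jax

def simple_estimator_sketch(logprior_fn, loglikelihood_fn, num_examples):
    def grad_estimator(position, minibatch):
        batch_size = jax.tree_util.tree_leaves(minibatch)[0].shape[0]
        # unbiased estimate of the log-posterior: prior + rescaled minibatch likelihood
        logposterior = lambda p: (
            logprior_fn(p)
            + num_examples / batch_size * loglikelihood_fn(p, minibatch)
        )
        return jax.grad(logposterior)(position)

    return grad_estimator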

Now for the Control Variates estimator:

import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias


schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree


# Get the CV gradient estimator and SGHMC algorithm
cv = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = cv.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)

    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )

SVRG is a CV estimator with updates. @bstaber's intuition is correct, and we can re-use the same code as for CV; we just need to re-initialize the control variate every cv_update_rate steps:

import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias


schedule: Generator[float]
data = PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree
cv_update_rate: int


# Get the CV gradient estimator (re-used as SVRG) and the SGHMC algorithm
svrg = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = svrg.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)

    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size
    )

    # SVRG is nothing more than CV that you can update
    if step % cv_update_rate == 0:
        grad_estimator = svrg.init(position, data)

While it is naively tempting to compute the gradient estimate outside of sghmc.step, SGHMC (like some other algorithms) needs to compute the gradient several times before returning a sample; the situation is not quite the same as with Optax.
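
Schematically (array position, friction and noise terms omitted; this is not the actual blackjax kernel), SGHMC's inner integrator re-evaluates the stochastic gradient at every one of its steps, so the estimator itself has to be available inside step:

import jax

def sghmc_inner_loop_sketch(rng_key, position, grad_estimator, minibatch, step_size, num_steps=10):
    momentum = jax.random.normal(rng_key, position.shape)
    for _ in range(num_steps):
        # a fresh gradient estimate is needed at every inner step
        gradient = grad_estimator(position, minibatch)
        momentum = momentum + step_size * gradient
        position = position + step_size * momentum
    return position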

We may go a step further to remove the awkwardness that @junpenglao saw in the code, and make CV effectively a wrapper around the Robbins-Monro estimator:

grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
cv_grad_estimator = gradients.cv(grad_estimator, centering_position, data)
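
A rough sketch of that wrapper (not the merged code): the CV estimate is the full-data gradient at the centering position plus the difference between the minibatch estimates at the current and centering positions. Evaluating the simple estimator on the full dataset gives the centering gradient, since the rescaling factor is then 1.

import jax

def cv_sketch(grad_estimator, centering_position, data):
    # one pass over the full dataset, done once when the wrapper is built
    centering_full_grad = grad_estimator(centering_position, data)

    def cv_grad_estimator(position, minibatch):
        grad_at_position = grad_estimator(position, minibatch)
        grad_at_center = grad_estimator(centering_position, minibatch)
        return jax.tree_util.tree_map(
            lambda full, g, g0: full + g - g0,
            centering_full_grad, grad_at_position, grad_at_center,
        )

    return cv_grad_estimator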

For SVRG, again, we need to rebuild the estimator every cv_update_rate steps:

    if step % cv_update_rate == 0:
        cv_grad_estimator = gradients.cv(grad_estimator, position, data)

rlouf avatar Oct 21 '22 13:10 rlouf

@junpenglao This is ready for review. I implemented the control variates as a wrapper around the simple estimator, as discussed above. The code is much cleaner than before, and SVRG is obtained by calling gradients.control_variates again on the gradient estimator within the sampling loop. So we can tick this off the list in #289 as well once this PR is merged.

rlouf avatar Oct 22 '22 15:10 rlouf

@junpenglao ping

rlouf avatar Nov 09 '22 15:11 rlouf