Add the control variates gradient estimator
We add the control variates gradient estimator for stochastic gradient MCMC algorithms. Control variates require one gradient evaluation on the whole dataset, which raises two questions that may be answered in subsequent PRs:
- Shouldn't we compute `logposterior_center` in a separate `init` function, and thus propagate a `GradientState`?
- How can we let users distribute this computation as they wish? This operation may need distributing with `pmap`, and we should allow that.
This PR is part of an effort to port SGMCMCJAX to blackjax, see #289.
- [ ] Make sure that the implementation is correct, and add a reference to https://arxiv.org/abs/1706.05439
- [ ] `logposterior_center` should be computed in an `init` function that is passed to the stochastic gradient MCMC kernels. This function is executed in the algorithm's `init` function, and the value is passed in a `GradientState`:
  - Doing computation in a `build_gradient*` function is surprising behavior for the users.
  - The SVRG gradient estimator, which is a variant of CV (with updating of the control value), will need this structure, so it comes at no cost.
- [ ] Because the gradient estimate is always computed at the end of the diffusion, we end up with one useless gradient computation. We should therefore compute the gradient in the kernels before the diffusion happens.
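For reference, the control variates estimator can be sketched as follows. This is a toy illustration of the idea from the paper linked above, not the PR's actual API; all names and signatures here are hypothetical:

```python
import jax
import jax.numpy as jnp

# Toy Gaussian log-likelihood summed over a batch of scalar observations.
def batch_loglik(position, batch):
    return jnp.sum(-0.5 * (batch - position) ** 2)

grad_fn = jax.grad(batch_loglik)

def cv_init(centering_position, data):
    # One full-dataset gradient evaluation, stored as the "GradientState".
    return centering_position, grad_fn(centering_position, data)

def cv_estimate(grad_state, position, minibatch, num_examples):
    centering_position, full_grad = grad_state
    scale = num_examples / minibatch.shape[0]
    # Full-data gradient at the center plus a rescaled minibatch correction.
    return full_grad + scale * (
        grad_fn(position, minibatch) - grad_fn(centering_position, minibatch)
    )

data = jnp.arange(10.0)
grad_state = cv_init(4.5, data)
estimate = cv_estimate(grad_state, 4.5, data[:5], data.shape[0])
# At the centering position the correction cancels, so the estimate equals
# the full-data gradient exactly (here, zero).
```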
Codecov Report
Merging #299 (8684689) into main (becd2d2) will decrease coverage by 0.05%. The diff coverage is 100.00%.
```diff
@@            Coverage Diff            @@
##             main    #299      +/-  ##
==========================================
- Coverage   89.79%  89.73%   -0.06%
==========================================
  Files          45      45
  Lines        2166    2134      -32
==========================================
- Hits         1945    1915      -30
+ Misses        221     219       -2
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| blackjax/kernels.py | 99.55% <100.00%> (+0.78%) | :arrow_up: |
| blackjax/mcmc/diffusions.py | 100.00% <100.00%> (ø) | |
| blackjax/mcmc/mala.py | 100.00% <100.00%> (ø) | |
| blackjax/sgmcmc/__init__.py | 100.00% <100.00%> (ø) | |
| blackjax/sgmcmc/diffusions.py | 100.00% <100.00%> (ø) | |
| blackjax/sgmcmc/gradients.py | 100.00% <100.00%> (ø) | |
| blackjax/sgmcmc/sghmc.py | 100.00% <100.00%> (ø) | |
| blackjax/sgmcmc/sgld.py | 100.00% <100.00%> (ø) | |
A few thoughts:
- I'm on the fence when it comes to keeping CV and implementing SVRG; not only do they require keeping track of a `GradientState`, they also need the full dataset at initialization (CV) or at every step (SVRG), which can be prohibitive in some scenarios. I would only consider keeping them in an Optax-like API where gradients are computed outside of the integrators. Optax has a control variates API we can get inspiration from.
- We cannot adopt an Optax-like API because some palindromic integrators (BADODAB for instance) require two gradient evaluations per step. Maybe there's something to learn from Optax's `MultiStep` interface?
- Second-order methods like AMAGOLD #375 may put additional constraints on the API, so we may want to sketch an implementation before moving forward.
All in all, if we didn't have methods where several gradient evaluations are needed to get one sample our life would be much easier. But also, maybe, none of that matters since complexity only affects the internals.
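To make the palindromic-integrator constraint concrete, here is a minimal sketch (not blackjax code; the integrator is simplified to its two "B" half-steps and the "A" drift) showing why the gradient cannot be precomputed once and passed in:

```python
calls = []

def grad_fn(q):
    # Gradient of a standard Gaussian log-density; records each evaluation.
    calls.append(q)
    return -q

def palindromic_step(position, momentum, step_size):
    momentum = momentum + 0.5 * step_size * grad_fn(position)  # first B half-step
    position = position + step_size * momentum                 # A step (D, O omitted)
    momentum = momentum + 0.5 * step_size * grad_fn(position)  # second B half-step
    return position, momentum

q, p = palindromic_step(1.0, 0.0, 0.1)
# The two B half-steps need the gradient at two different positions,
# so one step makes two gradient evaluations.
```

A single gradient argument computed outside the step, Optax-style, could only cover the first evaluation.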
High level question: here you also revert the change in https://github.com/blackjax-devs/blackjax/pull/293/ where you turn step size into a callable (so user can control how step size change during sampling), what is the reasoning behind the revert?
That's an important question. The current high-level interface of SgLD is:

```python
sgld = blackjax.sgld(grad_estimator, schedule)
```

Or, for a constant schedule:

```python
sgld = blackjax.sgld(grad_estimator, 1e-3)
```
And then to take a step:

```python
sgld.step(rng_key, state, minibatch)
```
Internally this forces us to increment a counter in the state, which I really dislike. For users this interface becomes impractical very quickly, as I saw when implementing Cyclical SgLD. Having the schedule baked in is also a common criticism of Optax. I much prefer the interface:
```python
sgld = blackjax.sgld(grad_estimator)
...
state = sgld.step(rng_key, state, minibatch, step_size)
```
My ideal lower-level interface would be:
```python
sgld = blackjax.sgld()
...
step_size = next(schedule)
minibatch = next(dataset)
gradients = grad_estimator(state, minibatch)
state = sgld.step(rng_key, state, gradients, step_size)
```
But I've expressed why that's difficult above.
So for low- to mid-level usage, users might be able to do something like:

```python
cosine_decay_scheduler = optax.cosine_decay_schedule(0.0001, decay_steps=total_steps, alpha=0.95)
for i in ...:  # could be in a jax.lax.scan as well
    step_size = cosine_decay_scheduler(i)
    minibatch = ...
    gradients = grad_estimator(state, minibatch)
    state = sgld.step(rng_key, state, gradients, step_size)
```
@junpenglao @bstaber The behavior in a loop was not very clear indeed, so let me give a full example. I modified (hopefully improved) the behavior slightly. For the simple Robbins-Monro estimator we have:
```python
import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias

schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree

# Get the simple (Robbins-Monro) gradient estimator and the SGHMC algorithm
grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size,
    )
```
Now for the Control Variates estimator:
```python
import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias

schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree

# Get the CV gradient estimator and the SGHMC algorithm
cv = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = cv.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size,
    )
```
SVRG is a CV estimator with updates. @bstaber's intuition is correct, and we can re-use the same code as for CV; we just need to re-initialize the control variate every cv_update_rate steps:
```python
import jax
import blackjax
import blackjax.sgmcmc.gradients as gradients  # need an alias

schedule: Generator[float]
data: PyTree
batches: Generator[jax.numpy.DeviceArray]
position: PyTree
centering_position: PyTree
cv_update_rate: int

# Get the CV gradient estimator and the SGHMC algorithm
svrg = gradients.cv(logprior_fn, loglikelihood_fn, num_examples)
sghmc = blackjax.sgmcmc.sghmc()

# Initialize the gradient state
# (SGHMC state is simply the position)
grad_estimator = svrg.init(centering_position, data)

rng_key = jax.random.PRNGKey(0)
for step in range(num_training_steps):
    _, rng_key = jax.random.split(rng_key)
    minibatch = next(batches)
    step_size = next(schedule)
    position, grad_state, info = sghmc.step(
        rng_key,
        position,
        grad_estimator,
        minibatch,
        step_size,
    )
    # SVRG is nothing more than CV that you can update:
    # re-center on the current position every cv_update_rate steps
    if step % cv_update_rate == 0:
        grad_estimator = svrg.init(position, data)
```
While it is naively tempting to compute the gradient estimate outside of sghmc.step, SGHMC (and some other algorithms) needs to compute the gradient several times before returning a sample; the situation is not quite the same as Optax.
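To illustrate that last point, here is a deliberately simplified SGHMC-like step (the names, signature, and diffusion are illustrative, not blackjax's actual implementation): the inner loop evaluates the gradient estimator once per substep, so a single gradient precomputed outside `sghmc.step` would not be enough.

```python
import jax
import jax.numpy as jnp

def sghmc_step(rng_key, position, grad_estimator, minibatch, step_size,
               num_substeps=5, friction=1.0):
    # Resample the momentum at the start of the trajectory.
    momentum_key, rng_key = jax.random.split(rng_key)
    momentum = jnp.sqrt(step_size) * jax.random.normal(momentum_key, jnp.shape(position))
    for _ in range(num_substeps):
        gradient = grad_estimator(position, minibatch)  # one evaluation per substep
        noise_key, rng_key = jax.random.split(rng_key)
        noise = jax.random.normal(noise_key, jnp.shape(position))
        momentum = (
            (1.0 - friction) * momentum
            + step_size * gradient
            + jnp.sqrt(2.0 * friction * step_size) * noise
        )
        position = position + momentum
    return position

# Count the gradient evaluations inside a single call to `sghmc_step`.
num_evaluations = []

def grad_estimator(pos, minibatch):
    num_evaluations.append(1)
    return -pos  # gradient of a standard Gaussian log-density

position = sghmc_step(jax.random.PRNGKey(0), jnp.array(1.0), grad_estimator, None, 1e-2)
# One gradient evaluation per substep: five in total here.
```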
We may go a step further to remove the awkwardness that @junpenglao saw in the code, and make CV effectively a wrapper around the Robbins-Monro estimator:
```python
grad_estimator = gradients.simple_estimator(logprior_fn, loglikelihood_fn, num_examples)
cv_grad_estimator = gradients.cv(grad_estimator, centering_position, data)
```
For SVRG, again, we need to rebuild the estimator every cv_update_rate steps:

```python
if step % cv_update_rate == 0:
    cv_grad_estimator = gradients.cv(grad_estimator, position, data)
```
@junpenglao This is ready for review. I implemented the control variates as a wrapper around the simple estimator, as discussed above. The code is much cleaner than before, and SVRG is obtained by calling gradients.control_variates again on the gradient estimator within the sampling loop. So we can tick this off the list in #289 as well once this PR is merged.
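For reviewers, the wrapper design can be sketched roughly like this (a toy stand-in, not the merged code; the estimator and function names are illustrative):

```python
import jax.numpy as jnp

def control_variates(grad_estimator, centering_position, data):
    # One full-data pass, done once when the wrapper is built.
    center_full_grad = grad_estimator(centering_position, data)

    def cv_grad_estimator(position, minibatch):
        # Minibatch correction around the fixed full-data control value.
        return center_full_grad + (
            grad_estimator(position, minibatch)
            - grad_estimator(centering_position, minibatch)
        )

    return cv_grad_estimator

# Toy estimator: gradient of a Gaussian log-likelihood summed over a batch.
def simple_estimator(position, batch):
    return jnp.sum(batch - position)

data = jnp.arange(6.0)
cv_grad_estimator = control_variates(simple_estimator, 2.0, data)
# At the centering position, the correction cancels and the wrapped
# estimator reproduces the full-data gradient for any minibatch.
```

Calling `control_variates` again with a new centering position is then exactly the SVRG update described above.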
@junpenglao ping