Testing strategy

ciguaran opened this issue 1 year ago

The goal of this issue is to kick off the discussion on our testing strategy. Hopefully we can also unify names and concepts.

Our current testing landscape

  • E2E tests: they run a complete sampling procedure on some model and make assertions about the shape of the resulting distribution. This may require many kernel steps.
    • Advantages:
      • checks that sampling works for a simple model, and that all the pieces ‘fit’: what one component produces (say, the proposal_generator) can be consumed by another (such as a trajectory builder).
      • checks that the output of one step can be used as the input of the next one.
      • checks properties that may only hold after many steps have been taken (perhaps stationarity).
    • Disadvantages:
      • slow to run
      • if a regression is introduced, it is difficult to find where it happened (slow developer feedback).
      • as more ‘ifs’ are added, it becomes impossible to test every combination.
  • Integration tests: they verify that a few components (for example, two) can be used together, without assembling them into something a user would use directly, such as a full kernel.
    • For example, running one step of a kernel.
      • Some mathematical properties may be checked (such as invariance of the target distribution).
      • Some code properties may be checked: the pieces work together correctly.
  • Unit tests: they verify the behavior of a single component. They are small and run fast.
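A minimal sketch of the three sizes, in plain Python (no JAX, so nothing is compiled), using a hypothetical toy random-walk Metropolis sampler on a standard normal target. `accept`, `rw_step` and `sample` are illustrative names, not Blackjax functions, and the E2E tolerances are loosely chosen:

```python
import math
import random


def logdensity(x):
    """Log-density of a standard normal, up to an additive constant."""
    return -0.5 * x * x


def accept(log_ratio, u):
    """Metropolis rule: accept when log(u) < log_ratio.

    The uniform draw `u` is an argument, so the rule can be unit-tested
    deterministically.
    """
    return math.log(u) < log_ratio


def rw_step(rng, x, scale=1.0):
    """One random-walk Metropolis step: proposal + density + acceptance rule."""
    proposal = x + rng.gauss(0.0, scale)
    log_ratio = logdensity(proposal) - logdensity(x)
    u = 1.0 - rng.random()  # uniform in (0, 1], so log(u) is always defined
    return proposal if accept(log_ratio, u) else x


def sample(seed, num_steps, x0=0.0):
    """Run a full chain and return every visited state."""
    rng = random.Random(seed)
    xs, x = [], x0
    for _ in range(num_steps):
        x = rw_step(rng, x)
        xs.append(x)
    return xs


# Unit test: one component, deterministic, runs in microseconds.
assert accept(log_ratio=0.0, u=0.5)       # log(0.5) ~ -0.69 < 0.0
assert not accept(log_ratio=-5.0, u=0.5)  # -0.69 > -5.0

# Integration test: a single step wires the pieces together.
assert isinstance(rw_step(random.Random(0), 0.0), float)

# E2E test: many steps, then a statistical check on the whole chain.
xs = sample(seed=0, num_steps=20_000)
mean = sum(xs) / len(xs)
var = sum((x - mean) ** 2 for x in xs) / len(xs)
assert abs(mean) < 0.15 and abs(var - 1.0) < 0.25
```

The unit test pins the acceptance rule down exactly, while the E2E test can only say that the first two moments are roughly right after 20,000 steps; if it failed, it would not tell us which of the three functions regressed.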

Some of the advantages and disadvantages listed above are specific to our domain (like the fact that certain properties of the output only hold after several steps of execution), but others are typical characteristics of E2E/big tests in general.

Our desired testing landscape

It is generally considered that a healthy test suite has a pyramid shape [1,2]: a large base of unit tests (quick developer feedback, easy to understand, good at catching regressions), complemented by a smaller set of integration tests that verify the connections between objects, and finally an even smaller set of big, slow, but user-facing E2E tests.

In systems design, it is always important to identify and distinguish between what Brooks [4] coined accidental complexity and essential complexity. Accidental complexity comes from the details of solving a problem on a computer: for example, the fact that Blackjax runs on top of JAX and the code needs to be jittable for performance. Essential complexity, by contrast, belongs to the problem being solved: for example, the fact that running HMC requires differentiating the log-density, or that implementing the RMH acceptance rule requires sampling. This is closely related to the concept of ‘domain logic’: the set of ideas that we want our code to represent, and the operations on them. In our case, these are the components used to build samplers and how they interoperate. We would like to represent samplers clearly in code, and thus be able to test them simply and quickly, isolating the essential complexity from the accidental complexity.

Theory to practice

All this looks great in theory, but in practice Blackjax is a library with a heavy dependence on JAX. We want all our code to be jittable and runnable on devices such as GPUs, so we need to balance testing the domain logic with making sure certain accidental properties hold.

Quoting from [5]:

JAX relies extensively on code transformation and compilation, meaning that it can be hard to ensure that code is properly tested. For instance, just testing a python function using JAX code will not cover the actual code path that is executed when jitted, and that path will also differ whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure and hard to catch bugs where XLA changes would lead to undesirable behaviours that however only manifest in one specific code transformation.

In order to account for this situation, the JAX ecosystem provides Chex which, among other things, exposes chex.TestCase and test variants that force code compilation and emulate the presence or absence of devices such as GPUs. Since compilation happens just in time, JAX needs to run the code on JAX-compatible inputs in order to extract their shapes, and then compiles the code for those shapes. As a consequence, Test Double [3] libraries such as MagicMock can't be used inside the code that chex compiles, since their objects are not valid JAX types.

A real example

Let’s analyze this code, which corresponds to the NUTS kernel:

def kernel(
    integrator: Callable = integrators.velocity_verlet,
    divergence_threshold: int = 1000,
    max_num_doublings: int = 10,
):
    def one_step(
        rng_key: PRNGKey,
        state: hmc.HMCState,
        logdensity_fn: Callable,
        step_size: float,
        inverse_mass_matrix: Array,
    ) -> Tuple[hmc.HMCState, NUTSInfo]:
        """Generate a new sample with the NUTS kernel."""

        (
            momentum_generator,
            kinetic_energy_fn,
            uturn_check_fn,
        ) = metrics.gaussian_euclidean(inverse_mass_matrix)
        symplectic_integrator = integrator(logdensity_fn, kinetic_energy_fn)
        proposal_generator = iterative_nuts_proposal(
            symplectic_integrator,
            kinetic_energy_fn,
            uturn_check_fn,
            max_num_doublings,
            divergence_threshold,
        )

        key_momentum, key_integrator = jax.random.split(rng_key, 2)

        position, logdensity, logdensity_grad = state
        momentum = momentum_generator(key_momentum, position)

        integrator_state = integrators.IntegratorState(
            position, momentum, logdensity, logdensity_grad
        )
        proposal, info = proposal_generator(key_integrator, integrator_state, step_size)
        proposal = hmc.HMCState(
            proposal.position, proposal.logdensity, proposal.logdensity_grad
        )
        return proposal, info

    return one_step

Let’s start by noting that kernel is a factory: it builds another object, in this case the one_step function. But that’s not its only responsibility (I am using the SOLID notion of responsibility here): one_step is also defined inside it, so all of one_step’s behavior is defined inside kernel too. Now, if we think about one_step, its main goal should be, given an input, to execute one NUTS step. But here, again, there are more responsibilities inside it:

  1. choosing gaussian_euclidean as the metric
  2. choosing iterative_nuts_proposal as the proposal
  3. applying the step

The main problem with this code from a testing standpoint is that everything listed here ends up being tested together: there is no way of decoupling a test of the step from an execution of the gaussian_euclidean code. Naturally, this leads to bigger tests, closer to E2E than to unit tests. Searching for usages of this code, it seems to be exercised by test_compilation and test_benchmarks; if we exclude those two, the file has 45% test coverage.
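The coupling can be reduced to a few lines of plain Python (no JAX). In this hypothetical sketch, `gaussian_momentum`, `build_kernel_hardwired`, `one_step_injected` and `build_kernel_split` are illustrative names, not Blackjax code: when the step is defined inside the factory, every test of it also runs the hard-coded collaborator, whereas once the collaborator is a parameter a test can pass a stub instead.

```python
from collections import namedtuple

# Simplified, hypothetical stand-in for hmc.HMCState.
State = namedtuple("State", ["position", "logdensity"])


def gaussian_momentum(key, position):
    """Hard-coded collaborator (a real one would sample from N(0, M))."""
    return [0.5 for _ in position]


def build_kernel_hardwired():
    # The factory both builds and defines the step, so the metric choice is
    # baked in: any test of `one_step` also executes `gaussian_momentum`.
    def one_step(key, state):
        momentum = gaussian_momentum(key, state.position)
        new_position = [q + p for q, p in zip(state.position, momentum)]
        return State(new_position, state.logdensity)

    return one_step


def one_step_injected(key, state, momentum_generator):
    # Same step logic, but the collaborator is a parameter.
    momentum = momentum_generator(key, state.position)
    new_position = [q + p for q, p in zip(state.position, momentum)]
    return State(new_position, state.logdensity)


def build_kernel_split():
    # The factory now only chooses collaborators; the behavior lives elsewhere.
    def one_step(key, state):
        return one_step_injected(key, state, gaussian_momentum)

    return one_step


def stub_momentum(key, position):
    """Stub: returns a canned momentum, independent of any metric."""
    return [10.0, 20.0]


state = State([1.0, 2.0], -0.3)
# Unit test of the step alone, without running the real momentum generator.
assert one_step_injected(None, state, stub_momentum).position == [11.0, 22.0]
# Both factories still agree on the default wiring.
assert build_kernel_split()(None, state) == build_kernel_hardwired()(None, state)
```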

Design suggestion 1.

In this case, one suggestion is to split the building/factory code from the code that actually samples.

# Copyright 2020- The Blackjax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Public API for the NUTS Kernel"""
import functools
from typing import Callable, NamedTuple, Tuple

import jax
import jax.numpy as jnp
import numpy as np

import blackjax.mcmc.hmc as hmc
import blackjax.mcmc.integrators as integrators
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.proposal as proposal
import blackjax.mcmc.termination as termination
import blackjax.mcmc.trajectory as trajectory
from blackjax.types import Array, PRNGKey, PyTree

__all__ = ["NUTSInfo", "init", "build_kernel"]

init = hmc.init

class NUTSInfo(NamedTuple):
    """Additional information on the NUTS transition.

    This additional information can be used for debugging or computing
    diagnostics.

    momentum
        The momentum that was sampled and used to integrate the trajectory.
    is_divergent
        Whether the difference in energy between the original and the new state
        exceeded the divergence threshold.
    is_turning
        Whether the sampling returned because the trajectory started turning
        back on itself.
    energy
        Energy of the transition.
    trajectory_leftmost_state
        The leftmost state of the full trajectory.
    trajectory_rightmost_state
        The rightmost state of the full trajectory.
    num_trajectory_expansions
        Number of subtrajectory samples that were taken.
    num_integration_steps
        Number of integration steps that were taken. This is also the number of
        states in the full trajectory.
    acceptance_rate
        Average acceptance probability across the entire trajectory.

    """

    momentum: PyTree
    is_divergent: bool
    is_turning: bool
    energy: float
    trajectory_leftmost_state: integrators.IntegratorState
    trajectory_rightmost_state: integrators.IntegratorState
    num_trajectory_expansions: int
    num_integration_steps: int
    acceptance_rate: float

def propose_from_momentum(rng_key, momentum_generator, proposal_generator, state, step_size):
    key_momentum, key_integrator = jax.random.split(rng_key, 2)

    position, logdensity, logdensity_grad = state
    momentum = momentum_generator(key_momentum, position)

    integrator_state = integrators.IntegratorState(
        position, momentum, logdensity, logdensity_grad
    )
    proposal, info = proposal_generator(key_integrator, integrator_state, step_size)
    proposal = hmc.HMCState(
        proposal.position, proposal.logdensity, proposal.logdensity_grad
    )
    return proposal, info

def build_kernel(
    integrator: Callable = integrators.velocity_verlet,
    divergence_threshold: int = 1000,
    max_num_doublings: int = 10,
):
    """Build an iterative NUTS kernel.

    This algorithm is an iteration on the original NUTS algorithm :cite:p:`hoffman2014no`
    with two major differences:

    - We do not use slice sampling but multinomial sampling for the proposal.
    - The trajectory expansion is not recursive but iterative :cite:p:`phan2019composable`.

    The implementation can seem unusual for those familiar with similar
    algorithms. Indeed, we do not conceptualize the trajectory construction as
    building a tree. We feel that the tree lingo, inherited from the recursive
    version, is unnecessarily complicated and hides the more general concepts
    upon which the NUTS algorithm is built.

    NUTS, in essence, consists in sampling a trajectory by iteratively choosing
    a direction at random and integrating in this direction a number of times
    that doubles at every step. From this trajectory we continuously sample a
    proposal. When the trajectory turns on itself or when we have reached the
    maximum trajectory length we return the current proposal.

    Parameters
    ----------
    integrator
        The symplectic integrator used to build trajectories.
    divergence_threshold
        The absolute difference in energy above which we consider
        a transition "divergent".
    max_num_doublings
        The maximum number of times we expand the trajectory by
        doubling the number of steps if the trajectory does not
        turn onto itself.

    """

    def kernel(
            rng_key: PRNGKey,
            state: hmc.HMCState,
            logdensity_fn: Callable,
            step_size: float,
            inverse_mass_matrix: Array,
    ) -> Tuple[hmc.HMCState, NUTSInfo]:
        """Generate a new sample with the NUTS kernel."""
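        (
            momentum_generator,
            kinetic_energy_fn,
            uturn_check_fn,
        ) = metrics.gaussian_euclidean(inverse_mass_matrix)
        symplectic_integrator = integrator(logdensity_fn, kinetic_energy_fn)
        proposal_generator = build_iterative_nuts_proposal(
            symplectic_integrator,
            kinetic_energy_fn,
            uturn_check_fn,
            max_num_doublings,
            divergence_threshold,
        )
        return propose_from_momentum(
            rng_key, momentum_generator, proposal_generator, state, step_size
        )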

    return kernel

def nuts_proposal(rng_key,
                  initial_state: integrators.IntegratorState,
                  step_size: float,
                  new_termination_state: Callable,
                  proposal_generator: Callable,
                  expand: Callable):
    initial_termination_state = new_termination_state(initial_state)
    initial_proposal = proposal_generator(initial_state)

    initial_trajectory = trajectory.zero_steps_trajectory(initial_state)

    initial_expansion_state = trajectory.DynamicExpansionState(
        0, initial_proposal, initial_trajectory, initial_termination_state
    )

    expansion_state, info = expand(
        rng_key, initial_expansion_state, initial_proposal.energy, step_size
    )
    is_diverging, is_turning = info
    num_doublings, sampled_proposal, new_trajectory, _ = expansion_state
    # Compute average acceptance probability across entire trajectory,
    # even over subtrees that may have been rejected
    acceptance_rate = (
        jnp.exp(sampled_proposal.sum_log_p_accept) / new_trajectory.num_states
    )

    info = NUTSInfo(
        initial_state.momentum,
        is_diverging,
        is_turning,
        sampled_proposal.energy,
        new_trajectory.leftmost_state,
        new_trajectory.rightmost_state,
        num_doublings,
        new_trajectory.num_states,
        acceptance_rate,
    )

    return sampled_proposal.state, info

def build_iterative_nuts_proposal(
        integrator: Callable,
        kinetic_energy: Callable,
        uturn_check_fn: Callable,
        max_num_expansions: int = 10,
        divergence_threshold: float = 1000,
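) -> Callable:
    """Iterative NUTS proposal.

    Parameters
    ----------
    integrator
        Symplectic integrator used to build the trajectory step by step.
    kinetic_energy
        Function that computes the kinetic energy.
    uturn_check_fn
        Function that determines whether the trajectory is turning on itself.
    max_num_expansions
        The number of sub-trajectory samples we take to build the trajectory.
    divergence_threshold
        Threshold above which we say that there is a divergence.

    Returns
    -------
    A kernel that takes a PRNG key, an integrator state and the step size
    (size of the integration step), and generates a new chain state and
    information about the transition.

    """
    (
        new_termination_state,
        update_termination_state,
        is_criterion_met,
    ) = termination.iterative_uturn_numpyro(uturn_check_fn)

    trajectory_integrator = trajectory.dynamic_progressive_integration(
        integrator,
        kinetic_energy,
        update_termination_state,
        is_criterion_met,
        divergence_threshold,
    )

    expand = trajectory.dynamic_multiplicative_expansion(
        trajectory_integrator,
        uturn_check_fn,
        max_num_expansions,
    )

    new, _ = proposal.proposal_generator(trajectory.hmc_energy(kinetic_energy), np.inf)

    return functools.partial(
        nuts_proposal,
        new_termination_state=lambda state: new_termination_state(state, max_num_expansions),
        proposal_generator=new,
        expand=expand,
    )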

This code looks far more testable than before, so let’s try to test it. Let’s first look at the components in the file:

class NUTSInfo -> no behavior
def build_kernel -> integration test
def propose_from_momentum -> unit test
def nuts_proposal -> unit test
def build_iterative_nuts_proposal -> integration test

I’ll start with propose_from_momentum. A first attempt at a test could look like this:

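import functools  # used by the Fake and Mock tests below

import chex
import jax
import jax.numpy as jnp
import numpy as np

from blackjax.mcmc.hmc import HMCState
from blackjax.mcmc.integrators import IntegratorState
from blackjax.mcmc.nuts import propose_from_momentum  # the refactored module above


class TestProposeFromMomentum(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_propose_from_momentum(self):
        """
        Given propose_from_momentum
        When calling it
        Then the proposal generator uses the momentum
        to generate a new proposal
        """
        state = HMCState(position=jnp.array([1., 2.]),
                         logdensity=0.1,
                         logdensity_grad=jnp.array([0.1, 0.1]))
        key = jax.random.PRNGKey(42)
        expected_info = jnp.array([1, 2, 3, 4])

        def momentum_generator(key, position):
            return jnp.array([50.0])

        def proposal_generator(key_integrator, integrator_state, step_size):
            # Stub: returns a canned proposal, whatever its inputs are.
            return IntegratorState(position=jnp.array([1., 2.]),
                                   momentum=jnp.array([50.0]),
                                   logdensity=0.3,
                                   logdensity_grad=0.8), expected_info

        def _nuts_kernel(key, state, step_size):
            return propose_from_momentum(rng_key=key,
                                         momentum_generator=momentum_generator,
                                         proposal_generator=proposal_generator,
                                         state=state, step_size=step_size)

        proposal, info = self.variant(_nuts_kernel)(key, state, 30)

        np.testing.assert_allclose(proposal.position, jnp.array([1., 2.]))
        np.testing.assert_allclose(proposal.logdensity, 0.3)
        np.testing.assert_allclose(proposal.logdensity_grad, 0.8)
        np.testing.assert_allclose(expected_info, info)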

Note that I have used a Stub [3] for proposal_generator and a Dummy for expected_info.
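To make this vocabulary from [3] concrete, here is a plain-Python analogue of the test above (no JAX and no key splitting). The `momentum_offset` parameter is a hypothetical knob added only so that a corrupted momentum can be injected; all names are illustrative. A Dummy is passed around but never inspected, and a Stub returns canned answers without looking at its inputs:

```python
from collections import namedtuple

# Simplified, hypothetical stand-ins for hmc.HMCState and IntegratorState.
HMCState = namedtuple("HMCState", "position logdensity logdensity_grad")
IntegratorState = namedtuple(
    "IntegratorState", "position momentum logdensity logdensity_grad"
)


def propose_from_momentum(key, momentum_generator, proposal_generator, state,
                          step_size, momentum_offset=0.0):
    """Plain-Python analogue of the function under test (no key splitting).

    `momentum_offset` is a hypothetical knob used only to inject a bug.
    """
    position, logdensity, logdensity_grad = state
    momentum = [m + momentum_offset for m in momentum_generator(key, position)]
    integrator_state = IntegratorState(position, momentum, logdensity, logdensity_grad)
    proposal, info = proposal_generator(key, integrator_state, step_size)
    new_state = HMCState(proposal.position, proposal.logdensity, proposal.logdensity_grad)
    return new_state, info


DUMMY_INFO = object()  # Dummy: passed through untouched, never inspected


def stub_momentum(key, position):
    """Stub: returns a canned momentum."""
    return [50.0]


def stub_proposal(key, integrator_state, step_size):
    """Stub: ignores its inputs and returns a canned proposal."""
    return IntegratorState([1.0, 2.0], [0.0], 0.3, 0.8), DUMMY_INFO


def stub_based_test(momentum_offset=0.0):
    state = HMCState([1.0, 2.0], 0.1, 0.2)
    proposal, info = propose_from_momentum(
        None, stub_momentum, stub_proposal, state, 30, momentum_offset
    )
    return proposal == HMCState([1.0, 2.0], 0.3, 0.8) and info is DUMMY_INFO


assert stub_based_test()                           # correct code passes
assert stub_based_test(momentum_offset=1_000_000)  # ...and so does the buggy one
```

Because the stub never looks at `integrator_state`, this kind of test cannot tell correct code from code that corrupts the momentum.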

It works, and the assertions make sense. But the test leaves room to break the code; let’s try that:

def propose_from_momentum(rng_key, momentum_generator, proposal_generator, state, step_size):
    key_momentum, key_integrator = jax.random.split(rng_key, 2)

    position, logdensity, logdensity_grad = state
    momentum = momentum_generator(key_momentum, position)

    integrator_state = integrators.IntegratorState(
        position, momentum + 1000000, logdensity, logdensity_grad  # deliberate regression
    )
    proposal, info = proposal_generator(key_integrator, integrator_state, step_size)
    proposal = hmc.HMCState(
        proposal.position, proposal.logdensity, proposal.logdensity_grad
    )
    return proposal, info

I have included a regression in the code on purpose. If we rerun the above test, no failure is raised: there is no assertion whatsoever that the integrator_state passed to the proposal_generator, or any other argument, is correctly computed, even though computing them is part of this function’s responsibility.

There are two solutions to this type of problem: use Chex runtime assertions so that proposal_generator becomes a Mock, or use a Fake.
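Outside of `jit`, the Mock idea can be sketched as a recording Spy: the double stores the arguments it receives and the test asserts on them afterwards. This is a plain-Python analogue; the names and the `momentum_offset` knob are hypothetical, used only to reproduce the regression above. Under `jax.jit` the recorded values would be abstract tracers rather than concrete arrays, which is why the JAX version needs assertions that run inside the compiled code:

```python
from collections import namedtuple

# Simplified, hypothetical stand-ins for hmc.HMCState and IntegratorState.
HMCState = namedtuple("HMCState", "position logdensity logdensity_grad")
IntegratorState = namedtuple(
    "IntegratorState", "position momentum logdensity logdensity_grad"
)


def propose_from_momentum(key, momentum_generator, proposal_generator, state,
                          step_size, momentum_offset=0.0):
    """Plain-Python analogue; `momentum_offset` injects the regression."""
    position, logdensity, logdensity_grad = state
    momentum = [m + momentum_offset for m in momentum_generator(key, position)]
    integrator_state = IntegratorState(position, momentum, logdensity, logdensity_grad)
    proposal, info = proposal_generator(key, integrator_state, step_size)
    new_state = HMCState(proposal.position, proposal.logdensity, proposal.logdensity_grad)
    return new_state, info


class SpyProposalGenerator:
    """Spy: records every call, then returns a canned proposal."""

    def __init__(self, result):
        self.result = result
        self.calls = []

    def __call__(self, key, integrator_state, step_size):
        self.calls.append((integrator_state, step_size))
        return self.result


def spy_based_test(momentum_offset=0.0):
    spy = SpyProposalGenerator((IntegratorState([1.0, 2.0], [0.0], 0.3, 0.8), "info"))
    state = HMCState([1.0, 2.0], 0.1, 0.2)
    propose_from_momentum(
        None, lambda key, position: [50.0], spy, state, 30, momentum_offset
    )
    [(seen_state, seen_step_size)] = spy.calls  # exactly one call expected
    # Assert on what the function under test passed to its collaborator.
    expected = IntegratorState([1.0, 2.0], [50.0], 0.1, 0.2)
    return seen_state == expected and seen_step_size == 30


assert spy_based_test()                               # correct code passes
assert not spy_based_test(momentum_offset=1_000_000)  # the regression is caught
```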

With Fake

class TestProposeFromMomentum(chex.TestCase):
    @chex.variants(with_jit=True, without_jit=True)
    def test_propose_from_momentum_with_fake(self):
        """
        Given propose_from_momentum
        When calling it
        Then the proposal generator uses the momentum
        to generate a new proposal
        """
        state = HMCState(position=jnp.array([1., 2.]),
                         logdensity=0.1,
                         logdensity_grad=jnp.array([0.1, 0.1]))
        key = jax.random.PRNGKey(42)
        expected_info = jnp.array([1, 2, 3, 4])

        def momentum_generator(key, position):
            return jnp.array([50.0])

        def proposal_generator(key_integrator, integrator_state, step_size):
            # Fake: a simple, deterministic transformation of its inputs.
            return IntegratorState(position=(integrator_state.position + jnp.array([1., 2.])) * step_size,
                                   momentum=integrator_state.momentum,
                                   logdensity=0.3,
                                   logdensity_grad=0.8), expected_info

        proposal, info = self.variant(functools.partial(propose_from_momentum,
                                                        momentum_generator=momentum_generator,
                                                        proposal_generator=proposal_generator)
                                      )(rng_key=key, state=state, step_size=30)

        # ([1, 2] + [1, 2]) * 30 = [60, 120]
        np.testing.assert_allclose(proposal.position, jnp.array([60., 120.]))
        np.testing.assert_allclose(proposal.logdensity, 0.3)
        np.testing.assert_allclose(proposal.logdensity_grad, 0.8)
        np.testing.assert_allclose(expected_info, info)

With Mock

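class TestProposeFromMomentum(chex.TestCase):

    @chex.variants(with_jit=True, without_jit=True)
    def test_propose_from_momentum_with_mock(self):
        """
        Given propose_from_momentum
        When calling it
        Then the proposal generator uses the momentum
        to generate a new proposal
        """
        state = HMCState(position=jnp.array([1., 2.]),
                         logdensity=0.1,
                         logdensity_grad=jnp.array([0.1, 0.1]))
        key = jax.random.PRNGKey(42)
        expected_info = jnp.array([1, 2, 3, 4])

        def momentum_generator(key, position):
            return jnp.array([50.0])

        def proposal_generator(key_integrator, integrator_state, step_size):
            # Mock: the double checks the arguments it receives. These are value
            # assertions, so under jit they need chex.chexify (see below).
            chex.assert_trees_all_close(state.position, integrator_state.position)
            chex.assert_trees_all_close(jnp.array([50.0]), integrator_state.momentum)
            return IntegratorState(position=jnp.array([1., 2.]),
                                   momentum=integrator_state.momentum,
                                   logdensity=0.3,
                                   logdensity_grad=0.8), expected_info

        # chexify turns the value assertions into runtime checks that also work
        # when the variant is jitted; async_check=False raises immediately.
        checked_fn = chex.chexify(
            self.variant(functools.partial(propose_from_momentum,
                                           momentum_generator=momentum_generator,
                                           proposal_generator=proposal_generator)),
            async_check=False,
        )
        proposal, info = checked_fn(rng_key=key, state=state, step_size=30)

        np.testing.assert_allclose(proposal.position, jnp.array([1., 2.]))
        np.testing.assert_allclose(proposal.logdensity, 0.3)
        np.testing.assert_allclose(proposal.logdensity_grad, 0.8)
        np.testing.assert_allclose(expected_info, info)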

Both approaches have advantages and disadvantages. The Mock is simpler to understand for someone new to the algorithm. With the Fake, the challenge is to think of the simplest behavior that still makes sense for the algorithm and that someone else can understand, so it may require more explanation.
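One way to keep a Fake simple and still catch the momentum regression is to make its output depend on every field it receives. A hypothetical sketch in plain Python: the fake performs a single explicit update `q + step_size * p` and evaluates a standard-normal log-density at the result. It is a trivial working "integrator", not a symplectic one, and all names (including the `momentum_offset` knob that injects the bug) are illustrative:

```python
from collections import namedtuple

# Simplified, hypothetical stand-ins for hmc.HMCState and IntegratorState.
HMCState = namedtuple("HMCState", "position logdensity logdensity_grad")
IntegratorState = namedtuple(
    "IntegratorState", "position momentum logdensity logdensity_grad"
)


def propose_from_momentum(key, momentum_generator, proposal_generator, state,
                          step_size, momentum_offset=0.0):
    """Plain-Python analogue; `momentum_offset` injects the regression."""
    position, logdensity, logdensity_grad = state
    momentum = [m + momentum_offset for m in momentum_generator(key, position)]
    integrator_state = IntegratorState(position, momentum, logdensity, logdensity_grad)
    proposal, info = proposal_generator(key, integrator_state, step_size)
    new_state = HMCState(proposal.position, proposal.logdensity, proposal.logdensity_grad)
    return new_state, info


def fake_proposal_generator(key, integrator_state, step_size):
    """Fake: one explicit update q + step_size * p on a standard-normal target.

    Cheap and deterministic, but the result depends on position, momentum
    and step size, so corrupting any of them changes the output.
    """
    q = [qi + step_size * pi
         for qi, pi in zip(integrator_state.position, integrator_state.momentum)]
    logdensity = -0.5 * sum(qi * qi for qi in q)
    grad = [-qi for qi in q]
    return IntegratorState(q, integrator_state.momentum, logdensity, grad), None


def fake_based_test(momentum_offset=0.0):
    state = HMCState([1.0, 2.0], -2.5, [-1.0, -2.0])
    proposal, _ = propose_from_momentum(
        None, lambda key, position: [3.0, -1.0], fake_proposal_generator,
        state, 2.0, momentum_offset,
    )
    # By hand: q = [1 + 2*3, 2 + 2*(-1)] = [7, 0]; logdensity = -0.5 * 49.
    return proposal == HMCState([7.0, 0.0], -24.5, [-7.0, -0.0])


assert fake_based_test()                               # correct code passes
assert not fake_based_test(momentum_offset=1_000_000)  # the regression is caught
```

The trade-off described above shows up here too: the expected values have to be derived by hand from the fake's rule, which a reader can check but which is one more thing to explain.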

Summary of suggestions so far:

  • Having more, smaller tests will help us improve developer feedback and make changes with confidence.
  • Splitting factories from functions will help us write more comprehensive tests.
  • We can use Chex runtime assertions to enhance our test doubles, given that we can't use test double libraries within JAX-compiled code.

[1] [2] [3] [4] [5]

Thank you for the detailed write-up. I will need some time to digest it, but I first want to express how much we appreciate the thought you put into it.

