Vectorized interpretation for pyro.scan

Open fritzo opened this issue 5 years ago • 3 comments

This issue proposes implementing a vectorized interpretation of pyro.scan that fully parallelizes over the time axis. It follows a hand-implementation of vectorization in pyro.contrib.epidemiology (in CompartmentalModel._relaxed_model() and ._vectorized_model()). This interpretation works only under replay or condition: the standard sampling interpretation cannot be time-parallelized, because each step's sample depends on the value drawn at the previous step. Supporting it will therefore require a three-way interaction between pyro.scan, poutine.condition, and the new interpretation.
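A minimal NumPy sketch of why conditioning is what makes time-parallelism possible, using a hypothetical Gaussian random walk x_t ~ Normal(x_{t-1}, sigma) (not a model taken from pyro.contrib.epidemiology). Once every x_t is observed, the "previous state" array is just the observations shifted by one step, so every per-step log-density can be evaluated at once and the total matches a sequential loop. The names normal_logpdf, sequential_log_prob and vectorized_log_prob are illustrative only.

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise log-density of Normal(loc, scale) evaluated at x.
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def sequential_log_prob(xs, init, sigma):
    # Standard interpretation: walk the time axis one step at a time.
    total = 0.0
    prev = init
    for x in xs:
        total += normal_logpdf(x, prev, sigma)
        prev = x
    return total


def vectorized_log_prob(xs, init, sigma):
    # Conditioned interpretation: step t only needs the observed x_{t-1},
    # so build all previous states at once by shifting the observations.
    prev = np.concatenate([[init], xs[:-1]])
    return normal_logpdf(xs, prev, sigma).sum()


rng = np.random.default_rng(0)
xs = np.cumsum(rng.normal(size=100))
print(np.isclose(sequential_log_prob(xs, 0.0, 1.0), vectorized_log_prob(xs, 0.0, 1.0)))  # True
```

Under the sampling interpretation there is no xs to shift yet, which is why the loop cannot be avoided there.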

Here is a vague sketch of the parallel implementation of pyro.scan:

import pyro
import torch
from pyro import poutine


def vectorized_scan(transition, time, init):
    """
    This assumes ``init`` is a dict mapping unindexed sample site name (i.e.
    "x" rather than "x_0") to value. TODO generalize to PyTree.
    This assumes ``time`` is a range object. TODO generalize to jnp.arange?
    """
    # Trace the first step and assume model structure is fixed.
    # In pyro.contrib.epidemiology we do this once at the start of inference.
    # Maybe we could memoize to avoid duplicated execution?
    with poutine.block(), poutine.trace() as tr:
        t = 0
        transition(init, t)
    # Strip the "_0" suffix to recover unindexed names, e.g. "x_0" -> "x".
    names = [name[: -len("_0")] for name in tr.trace.stochastic_nodes
             if name.endswith("_0")]

    # The remainder is vectorized over time.
    with pyro.plate("time", len(time)):  # or maybe jax.vmap
        t = slice(0, len(time), 1)  # or maybe jnp.arange

        # Record vectorized values.
        curr = {}
        prev = {}
        # NOTE: block_trace_but_allow_replay_and_condition is a proposed
        # handler; it does not exist in poutine yet.
        with poutine.block_trace_but_allow_replay_and_condition():
            for name in names:
                name_0 = "{}_{}".format(name, 0)
                name_t = "{}_{}".format(name, t)
                site_0 = tr.trace.nodes[name_0]
                curr[name] = pyro.sample(name_t, site_0["fn"])
                # Shift by one step: step t transitions from the state at
                # t - 1, and step 0 transitions from ``init``.
                prev[name] = torch.cat([init[name].unsqueeze(0), curr[name][:-1]])

        # Execute vectorized transition.
        transition(prev, t)

    return ...
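The bookkeeping inside the sketch (discovering the per-step site names from step 0, then building stacked curr values and one-step-shifted prev values) can be tried in isolation with plain NumPy. stack_conditioned_sites is a hypothetical helper: it reads observed values from an ordinary dict instead of from Pyro's replay/condition handlers, and it splits names on the last underscore so that a site like "x_10" is not mistaken for a step-0 site.

```python
import numpy as np


def stack_conditioned_sites(data, init, num_steps):
    """Hypothetical helper mirroring the sketch's data layout.

    data: dict mapping indexed site names ("x_0", "x_1", ...) to observed values.
    init: dict mapping unindexed names ("x") to the state before step 0.
    Returns (prev, curr) dicts keyed by unindexed name, each stacked over time.
    """
    # Discover per-step sites from their step-0 names; rsplit avoids matching "x_10".
    names = [key.rsplit("_", 1)[0] for key in data if key.rsplit("_", 1)[-1] == "0"]
    curr, prev = {}, {}
    for name in names:
        curr[name] = np.stack([np.asarray(data[f"{name}_{t}"]) for t in range(num_steps)])
        # prev[name][t] is the state that step t transitions from.
        prev[name] = np.concatenate([np.asarray(init[name])[None], curr[name][:-1]])
    return prev, curr


prev, curr = stack_conditioned_sites({"x_0": 1.0, "x_1": 2.0, "x_2": 4.0}, {"x": 0.0}, num_steps=3)
print(prev["x"], curr["x"])  # [0. 1. 2.] [1. 2. 4.]
```

Keeping curr and prev as separate arrays of equal length is what lets a single vectorized transition call see every (previous, current) pair at once.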

cc @fehiepsi @eb8680

fritzo · Jul 15 '20 17:07

This is nice to have! If you need any functionality from jax/numpyro, just let me know. Regarding block_trace_but_allow_replay_and_condition, it is probably simpler in NumPyro, because you can use a control_flow primitive (as in numpyro.contrib.control_flow.scan) and check whether substitute_stack is empty. If that stack provides all the information needed for vectorization, then you can run vectorized_scan.
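The dispatch decision described here can be sketched without any NumPyro internals. In this hypothetical stand-in, a plain dict of conditioned values plays the role of substitute_stack (whose actual structure this sketch does not model): vectorize only if every indexed site "name_t" already has a value, otherwise fall back to sequential execution. can_vectorize and choose_scan are illustrative names, not NumPyro functions.

```python
from typing import Callable, Mapping, Sequence


def can_vectorize(data: Mapping[str, object], names: Sequence[str], num_steps: int) -> bool:
    # Every site "name_t" for t in [0, num_steps) must already be provided;
    # one missing site means that step must be sampled, forcing sequential execution.
    return all(f"{name}_{t}" in data for name in names for t in range(num_steps))


def choose_scan(data, names, num_steps, sequential_scan: Callable, vectorized_scan: Callable) -> Callable:
    # Pick an interpretation; both scan callables are placeholders supplied by the caller.
    if can_vectorize(data, names, num_steps):
        return vectorized_scan
    return sequential_scan


data = {"x_0": 1.0, "x_1": 2.0}
print(can_vectorize(data, ["x"], num_steps=2))  # True
print(can_vectorize(data, ["x"], num_steps=3))  # False: "x_2" would have to be sampled
```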

fehiepsi · Jul 16 '20 19:07

@fritzo Can I take a stab at this? :)

fehiepsi · May 26 '21 04:05

@fehiepsi sure, and let me know if you want me to review any code.

fritzo · May 28 '21 21:05