numpyro
Vectorized interpretation for pyro.scan
This proposes implementing a vectorized interpretation of pyro.scan that fully parallelizes over the time axis. It follows a hand-implemented vectorization in pyro.contrib.epidemiology (in CompartmentalModel._relaxed_model() and ._vectorized_model()). This interpretation works only under replay or condition: the standard interpretation of sampling cannot be time-parallelized, because sampling each step requires the value sampled at the previous step. Therefore this will require a three-way interaction between pyro.scan, poutine.condition, and the new interpretation.
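To see why conditioning is what makes time-parallelism possible, here is a toy sketch in plain NumPy rather than Pyro. The AR(1) transition and the names `transition_logp`, `sequential_logp` and `vectorized_logp` are illustrative assumptions. Once every x_t is fixed, each factor log p(x_t | x_{t-1}) depends only on known values, so shifting the observed series by one step yields all the "previous" values at once, and the whole log-density becomes a single vectorized expression:

```python
import numpy as np


def normal_logpdf(x, loc, scale):
    # Elementwise log-density of Normal(loc, scale).
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def transition_logp(prev_x, x, rho=0.9, sigma=0.5):
    # log p(x_t | x_{t-1}) for a toy AR(1) transition (illustrative parameters);
    # works on scalars or arrays.
    return normal_logpdf(x, rho * prev_x, sigma)


def sequential_logp(init, xs):
    # One step at a time, as a sequential scan would evaluate it.
    total = 0.0
    prev = init
    for x in xs:
        total += transition_logp(prev, x)
        prev = x
    return total


def vectorized_logp(init, xs):
    # All steps at once: prev[t] is init at t == 0, else xs[t - 1].
    prev = np.concatenate([[init], xs[:-1]])
    return transition_logp(prev, xs).sum()


rng = np.random.default_rng(0)
xs = rng.normal(size=100)  # stands in for values fixed by condition/replay
print(sequential_logp(0.0, xs), vectorized_logp(0.0, xs))
```

Without conditioning the loop cannot be removed, since drawing x_t requires x_{t-1} to have been drawn first.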
Here is a vague sketch of the parallel implementation of pyro.scan:
import torch
import pyro
import pyro.poutine as poutine


def vectorized_scan(transition, time, init):
    """
    This assumes ``init`` is a dict mapping unindexed sample site name (i.e.
    "x" rather than "x_0") to value. TODO generalize to PyTree.
    This assumes ``time`` is a range object. TODO generalize to jnp.arange?
    """
    # Trace the first step and assume model structure is fixed.
    # In pyro.contrib.epidemiology we do this once at the start of inference.
    # Maybe we could memoize to avoid duplicated execution?
    with poutine.block(), poutine.trace() as tr:
        t = 0
        transition(init, t)
    # Recover unindexed names by stripping the "_0" suffix, e.g. "x_0" -> "x".
    names = [name[:-len("_0")] for name in tr.trace.stochastic_nodes
             if name.endswith("_0")]

    # The remainder is vectorized over time.
    with pyro.plate("time", len(time)):  # or maybe jax.vmap
        t = slice(0, len(time), 1)  # or maybe jnp.arange
        # Record vectorized values.
        curr = {}
        prev = {}
        # NOTE: block_trace_but_allow_replay_and_condition is a proposed
        # handler; it does not exist in Pyro yet.
        with poutine.block_trace_but_allow_replay_and_condition():
            for name in names:
                name_0 = "{}_{}".format(name, 0)
                name_t = "{}_{}".format(name, t)  # TODO naming scheme for slices
                site_0 = tr.trace.nodes[name_0]
                curr[name] = pyro.sample(name_t, site_0["fn"])
                # Shift by one step: prev[t] is init at t == 0, else curr[t - 1].
                prev[name] = torch.cat([init[name].unsqueeze(0), curr[name][:-1]])
        # Execute vectorized transition.
        transition(prev, t)
    return ...
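The vectorized path in the sketch needs every per-step value to be supplied by condition or replay, stacked along a leading time dimension and shifted by one step to form `prev`. The helper below sketches that bookkeeping in NumPy under the `"{name}_{t}"` site-naming convention used in the sketch. `gather_conditioned` and `shift_by_one` are hypothetical names, and returning `None` when a value is missing is an assumed convention for signalling that the sequential interpretation is needed.

```python
import numpy as np


def gather_conditioned(data, names, num_steps):
    """Stack per-step values {"x_0": v0, "x_1": v1, ...} into {"x": array}.

    Hypothetical helper: returns None if any per-step value is missing,
    meaning the caller must fall back to the sequential interpretation.
    """
    curr = {}
    for name in names:
        keys = ["{}_{}".format(name, t) for t in range(num_steps)]
        if not all(k in data for k in keys):
            return None
        # Leading axis is time; trailing axes keep each value's own shape.
        curr[name] = np.stack([np.asarray(data[k]) for k in keys])
    return curr


def shift_by_one(init, curr):
    # prev[name][t] is init[name] at t == 0 and curr[name][t - 1] otherwise.
    return {
        name: np.concatenate([np.asarray(init[name])[None], value[:-1]])
        for name, value in curr.items()
    }


curr = gather_conditioned({"x_0": 1.0, "x_1": 2.0, "x_2": 3.0}, ["x"], 3)
prev = shift_by_one({"x": 0.0}, curr)
print(curr["x"], prev["x"])
```

Keeping the shift separate from the gathering mirrors the sketch, where `curr` is recorded at the sample sites and `prev` is derived from it before calling `transition`.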
cc @fehiepsi @eb8680
This is nice to have! If you need any functionality from jax/numpyro, just let me know. Regarding block_trace_but_allow_replay_and_condition, it is probably simpler in NumPyro because you can use a control_flow primitive (as in numpyro.contrib.control_flow.scan) and check whether substitute_stack is empty. If that stack provides all the information needed for vectorization, then you can run vectorized_scan.
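One possible reading of that dispatch, as a toy sketch: before running the scan, check whether the conditioning data (a plain dict standing in for NumPyro's substitute stack) supplies a value for every step. If it does, evaluate all transitions at once; otherwise fall back to the sequential loop. This uses plain NumPy and a hypothetical AR(1) model (`ar1_scan`, `RHO` and `SIGMA` are illustrative); it does not reflect NumPyro's actual handler internals.

```python
import numpy as np

RHO, SIGMA = 0.9, 0.5  # hypothetical AR(1) parameters


def log_normal(x, loc, scale):
    # Elementwise log-density of Normal(loc, scale).
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def ar1_scan(init, num_steps, data, rng, allow_vectorized=True):
    """Return (xs, log_density, mode) for a toy AR(1) scan.

    ``data`` plays the role of the substitute/condition stack: it maps
    per-step site names like "x_3" to observed values.
    """
    keys = ["x_{}".format(t) for t in range(num_steps)]
    if allow_vectorized and all(k in data for k in keys):
        # Every step is conditioned: evaluate all transition factors at once.
        xs = np.array([data[k] for k in keys], dtype=float)
        prev = np.concatenate([[init], xs[:-1]])
        return xs, log_normal(xs, RHO * prev, SIGMA).sum(), "vectorized"
    # Otherwise use the sequential interpretation, sampling any
    # unconditioned step from its transition distribution.
    xs, total, prev = [], 0.0, init
    for k in keys:
        loc = RHO * prev
        x = data[k] if k in data else loc + SIGMA * rng.standard_normal()
        total += log_normal(x, loc, SIGMA)
        xs.append(x)
        prev = x
    return np.array(xs, dtype=float), total, "sequential"


rng = np.random.default_rng(0)
full = {"x_{}".format(t): float(t) for t in range(5)}
print(ar1_scan(0.0, 5, full, rng)[2])              # fully conditioned
print(ar1_scan(0.0, 4, {"x_0": 0.5}, rng)[2])      # partially conditioned
```

Keeping the sequential loop as the default means the vectorized path is purely an optimization, and it must agree with the sequential result whenever it applies.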
@fritzo Can I take a stab at this? :)
@fehiepsi sure, and let me know if you want me to review any code.