
Support "transitive" `Scan` log-probabilities

Open brandonwillard opened this issue 2 years ago • 0 comments

Currently, `Scan` log-probability support only handles cases in which the `MeasurableVariable` is created inside the body/step function of the `Scan`. It does not handle cases in which the body/step function simply references a `MeasurableVariable` that the `Scan` iterates over.

For example, the following is not supported:

import aesara
import aesara.tensor as at

from aeppl.joint_logprob import factorized_joint_logprob


srng = at.random.RandomStream(seed=2320)
N = 10

Y_rv = srng.normal(0, 1, size=N, name="Y")


def step_fn(y_t):
    return y_t


Y_1T_rv, _ = aesara.scan(
    fn=step_fn,
    sequences=[Y_rv],
    strict=True,
)

y_vv = Y_1T_rv.clone()
y_vv.name = "y"

logp_parts = factorized_joint_logprob({Y_1T_rv: y_vv})

This example is trivial, but changing `step_fn` so that it performs a supported, measurable operation on `y_t` (e.g. indexing a mixture) would fail for the same reason.

When a value is assigned to a `Scan` output term like `Y_1T_rv`, we could "push" the relevant `sequences` inputs into the step function. In other words, we could construct the type of graph we currently handle, in which the `MeasurableVariable` is created inside the step function.

Working from the example above, we would rewrite the Scan into something like the following:

# Apply a rewrite like `local_rv_size_lift` to get properly `size`-broadcasted parameters
# in a new variable `new_Y_rv`
mu_bcast, sigma_bcast = new_Y_rv.owner.inputs[3:]

def new_step_fn(mu_t, sigma_t):
    return Y_rv.owner.op(mu_t, sigma_t, name="Y[t]")

new_Y_1T_rv, _ = aesara.scan(
    fn=new_step_fn,
    sequences=[mu_bcast, sigma_bcast],
    strict=True,
)
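
The rewrite relies on the elements of `Y_rv` being independent, so that the log-probability of the whole vector can be computed one step at a time from the `size`-broadcasted parameters. The following is a minimal plain-Python sketch of that identity, not an Aesara or AePPL implementation: the names `normal_logpdf`, `logp_vectorized`, and `logp_scanned` are hypothetical, and Python lists stand in for the tensors and the `Scan` loop.

```python
import math
import random


def normal_logpdf(y, mu, sigma):
    """Log-density of a univariate normal distribution evaluated at `y`."""
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def logp_vectorized(ys, mu, sigma):
    """Element-wise log-density of a vector drawn with scalar parameters
    and `size=len(ys)` (stand-in for the original `Y_rv`)."""
    return [normal_logpdf(y_t, mu, sigma) for y_t in ys]


def logp_scanned(ys, mu_bcast, sigma_bcast):
    """Per-step log-density, iterating over broadcasted parameters the way
    the rewritten step function would receive them as `sequences`."""
    out = []
    for y_t, mu_t, sigma_t in zip(ys, mu_bcast, sigma_bcast):
        out.append(normal_logpdf(y_t, mu_t, sigma_t))
    return out


N = 10
mu, sigma = 0.0, 1.0

# Stand-in for the `size`-lifted parameters: one value per iteration.
mu_bcast = [mu] * N
sigma_bcast = [sigma] * N

rng = random.Random(2320)
y = [rng.gauss(mu, sigma) for _ in range(N)]

logp_vec = logp_vectorized(y, mu, sigma)
logp_seq = logp_scanned(y, mu_bcast, sigma_bcast)

# Both formulations assign the same log-density to every element.
assert all(math.isclose(a, b) for a, b in zip(logp_vec, logp_seq))
```

Because the two lists agree element by element, assigning a value to the `Scan` output and evaluating the rewritten graph should give the same log-probability terms as evaluating the original random vector directly, at least for this independent, identity-step case.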

brandonwillard • Oct 15 '21 00:10