RFC: tracing variables within `JointDistribution`s
Background
It is sometimes useful to trace intermediate values computed as part of a `JointDistribution*` model. The current solution is to use `tfd.Deterministic`. For example, to trace the mean of a simple linear regression on a single feature, a user might write:
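```python
import tensorflow_probability as tfp

tfd = tfp.distributions
feature = 0.5

@tfd.JointDistributionCoroutine
def model():
    intercept = yield tfd.Normal(loc=0.0, scale=1.0, name="intercept")
    slope = yield tfd.Normal(loc=0.0, scale=1.0, name="slope")
    # `mean` must be yielded so that the coroutine records it in each draw
    mean = yield tfd.Deterministic(intercept + slope * feature, name="mean")
    yield tfd.Normal(loc=mean, scale=1.0, name="response")
```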
A call to `model.sample()` returns a named tuple that includes the value of `mean`, conditional on `intercept` and `slope` (and `feature`). However, to compute the log probability density of the model given `response`, `intercept`, and `slope`, we must also pass a value of `mean` to `model.log_prob`. This value must be consistent with `intercept` and `slope` (otherwise `tfd.Deterministic` assigns it zero probability and `log_prob` returns `-inf`), which forces the user to duplicate the expression for `mean` outside the model object. For example:
```python
intercept = 0.1
slope = 0.2
feature = 0.5
response = 0.21
# Since `mean` is deterministic, we should not have to re-compute it outside of `model`
mean = intercept + slope * feature
lp = model.log_prob(intercept=intercept, slope=slope, mean=mean, response=response)
```
This is not only wasteful in terms of keystrokes but also error-prone: if the model changes, the copy of the expression for `mean` outside it must be updated by hand to match.
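To make the consistency requirement concrete, here is a pure-Python stand-in (not TFP code) for the point-mass density of `tfd.Deterministic`: probability one when the value lies within `atol` of `loc`, zero otherwise. The helper name `deterministic_log_prob` is hypothetical, and the sketch assumes a tolerance of zero and ignores TFP's `rtol`.

```python
import math

def deterministic_log_prob(x, loc, atol=0.0):
    """Hypothetical stand-in for a point mass at `loc`.

    Returns 0.0 (log of probability one) when |x - loc| <= atol and
    -inf (log of probability zero) otherwise.
    """
    return 0.0 if abs(x - loc) <= atol else -math.inf

intercept, slope, feature = 0.1, 0.2, 0.5
loc = intercept + slope * feature  # the value the model computes internally

# Re-computing `mean` with the same expression contributes nothing to log_prob.
print(deterministic_log_prob(intercept + slope * feature, loc))  # 0.0
# A stale copy of the expression (the model changed, this line did not) gives -inf.
print(deterministic_log_prob(intercept + 2 * slope * feature, loc))  # -inf
```

A single mismatch in the duplicated expression therefore turns the whole joint log density into `-inf`, rather than producing a slightly wrong number.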
Suggested solution
A potential solution would be to add a wrapper class, analogous to `JointDistributionCoroutine.Root`, called `JointDistributionCoroutine.Trace`. It would flag an expression for tracing in the forward generative process (i.e. `model.sample()`) but exclude the associated variable from the CDF/CMF and PDF/PMF-related methods such as `log_prob`. We could then write:
```python
@tfd.JointDistributionCoroutine
def model():
    intercept = yield tfd.Normal(loc=0.0, scale=1.0, name="intercept")
    slope = yield tfd.Normal(loc=0.0, scale=1.0, name="slope")
    # `Trace` is the proposed wrapper; it is not an existing TFP API
    mean = yield Trace(intercept + slope * feature, name="mean")
    yield tfd.Normal(loc=mean, scale=1.0, name="response")

draw = model.sample(seed=[0, 0])
# The traced `mean` in `draw` is simply ignored
model.log_prob(draw)
# `mean` does not have to be supplied
model.log_prob(intercept=draw.intercept, slope=draw.slope, response=draw.response)
```
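To check that the proposed semantics hang together, here is a toy, pure-Python model of a coroutine-based joint distribution. It is not TFP: `ToyJointDistributionCoroutine`, `Normal`, and `Trace` below are all hypothetical stand-ins. A yielded `Trace` is recorded by `sample` but skipped when `log_prob` accumulates densities, and its value is always recomputed from the upstream variables.

```python
import math
import random
from dataclasses import dataclass

@dataclass
class Normal:
    loc: float
    scale: float
    name: str

    def sample(self, rng):
        return rng.gauss(self.loc, self.scale)

    def log_prob(self, x):
        z = (x - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)

@dataclass
class Trace:
    """Hypothetical marker: keep `value` in samples, skip it in log_prob."""
    value: float
    name: str

class ToyJointDistributionCoroutine:
    def __init__(self, model_fn):
        self.model_fn = model_fn

    def sample(self, seed=None):
        rng = random.Random(seed)
        gen, draw = self.model_fn(), {}
        try:
            node = next(gen)
            while True:
                # Traced nodes pass their computed value through unchanged.
                value = node.value if isinstance(node, Trace) else node.sample(rng)
                draw[node.name] = value
                node = gen.send(value)
        except StopIteration:
            pass
        return draw

    def log_prob(self, **values):
        gen, total = self.model_fn(), 0.0
        try:
            node = next(gen)
            while True:
                if isinstance(node, Trace):
                    # Recomputed inside the model; any user-supplied value is ignored.
                    value = node.value
                else:
                    value = values[node.name]
                    total += node.log_prob(value)
                node = gen.send(value)
        except StopIteration:
            pass
        return total

feature = 0.5

def model():
    intercept = yield Normal(0.0, 1.0, "intercept")
    slope = yield Normal(0.0, 1.0, "slope")
    mean = yield Trace(intercept + slope * feature, "mean")
    yield Normal(mean, 1.0, "response")

joint = ToyJointDistributionCoroutine(model)
draw = joint.sample(seed=0)
lp_full = joint.log_prob(**draw)  # the traced `mean` is ignored
lp = joint.log_prob(intercept=draw["intercept"], slope=draw["slope"], response=draw["response"])
print(lp == lp_full)  # True
```

Because the traced value is recomputed rather than read from the caller's arguments, both `log_prob` calls agree and a stale `mean` can never enter the density. A real TFP implementation would additionally need to handle batching, sample shapes, and the other `JointDistribution` methods, all of which this sketch ignores.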
Does this seem like a feasible addition? (I may have some resources to devote to it.)