pymc
Implement logprob inference for binary operations
Description
Using CDFs, it should be simple to derive the logp of graphs of the form:
import pymc as pm
x = pm.Normal.dist()
y = x < 0.5
pm.logp(y, value=1).eval() # same as pm.logcdf(x, 0.5)
pm.logp(y, value=0).eval() # same as log(1 - exp(pm.logcdf(x, 0.5)))
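The CDF identity in these comments can be checked numerically without PyMC. This is a minimal sketch, assuming x ~ Normal(0, 1) so that the CDF can be written with math.erfc. The helpers normal_logcdf and comparison_logp are hypothetical illustrations, not PyMC functions.

```python
import math

def normal_logcdf(c: float, mu: float = 0.0, sigma: float = 1.0) -> float:
    """Log CDF of Normal(mu, sigma) at c (hypothetical helper, not PyMC API)."""
    z = (c - mu) / (sigma * math.sqrt(2.0))
    return math.log(0.5 * math.erfc(-z))

def comparison_logp(value: bool, c: float) -> float:
    """logp of y = (x < c) for x ~ Normal(0, 1), following the comments above."""
    logcdf = normal_logcdf(c)
    if value:
        # P(y = 1) = P(x < c) = CDF(c)
        return logcdf
    # P(y = 0) = 1 - CDF(c); log1p(-exp(.)) is the stable form of log(1 - exp(.))
    return math.log1p(-math.exp(logcdf))

print(comparison_logp(True, 0.5))
print(comparison_logp(False, 0.5))
```

The two probabilities always sum to 1, since y can only be True or False.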
https://github.com/pymc-devs/pymc/blob/main/pymc/logprob/censoring.py includes rewrites for other operations that rely on the CDF, such as clip and round.
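The probability math behind those two rewrites can be sketched in a few lines. This illustrates the math only, not the code in censoring.py, and assumes a standard Normal x. The helper names are hypothetical. For y = round(x), every x in (v - 0.5, v + 0.5) maps to v. For y = clip(x, lower, upper), all mass below lower piles up at lower, and all mass above upper piles up at upper.

```python
import math

def normal_cdf(c: float) -> float:
    """CDF of a standard Normal at c."""
    return 0.5 * math.erfc(-c / math.sqrt(2.0))

def rounded_logp(v: int) -> float:
    """logp of y = round(x) at integer v, for x ~ Normal(0, 1) (hypothetical helper)."""
    # Boundary points have zero probability for a continuous x,
    # so the tie-breaking rule of round does not affect the result.
    return math.log(normal_cdf(v + 0.5) - normal_cdf(v - 0.5))

def clipped_logp_at_lower(lower: float) -> float:
    """logp of the point mass that clip(x, lower, upper) puts at lower."""
    return math.log(normal_cdf(lower))

def clipped_logp_at_upper(upper: float) -> float:
    """logp of the point mass that clip(x, lower, upper) puts at upper."""
    return math.log1p(-normal_cdf(upper))

print(rounded_logp(0), clipped_logp_at_lower(-1.0), clipped_logp_at_upper(1.0))
```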
More challenging, but also fun, would be to support all and any. Assuming the binary variables are independent, the logp that all of them evaluate to True is the sum of their individual logps of being True, and any follows from the complement: P(any) = 1 - P(all(x == False)).
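A minimal sketch of the all/any reasoning, assuming independent binary variables, since that is what makes the logps add up. Each variable is described by its log P(True). The function names are hypothetical. brute_force_logp enumerates every outcome and exists only to confirm the closed forms.

```python
import itertools
import math
from typing import Callable, Sequence

def log1mexp(a: float) -> float:
    """Return log(1 - exp(a)) for a < 0."""
    return math.log1p(-math.exp(a))

def all_logp(value: bool, logp_true: Sequence[float]) -> float:
    """logp of all(b) for independent binary b_i with log P(b_i = True) = logp_true[i]."""
    # Independence: P(all True) is a product, so the logps add up.
    log_all_true = sum(logp_true)
    return log_all_true if value else log1mexp(log_all_true)

def any_logp(value: bool, logp_true: Sequence[float]) -> float:
    """logp of any(b), using P(any) = 1 - P(all(b == False))."""
    log_all_false = sum(log1mexp(lp) for lp in logp_true)
    return log1mexp(log_all_false) if value else log_all_false

def brute_force_logp(reducer: Callable, value: bool, logp_true: Sequence[float]) -> float:
    """Reference answer: enumerate all 2**n outcomes and add up matching probabilities."""
    total = 0.0
    for outcome in itertools.product([True, False], repeat=len(logp_true)):
        if reducer(outcome) == value:
            total += math.prod(
                math.exp(lp) if b else -math.expm1(lp)  # P(True) or 1 - P(True)
                for b, lp in zip(outcome, logp_true)
            )
    return math.log(total)

logp_true = [math.log(0.3), math.log(0.8), math.log(0.5)]
print(math.exp(all_logp(True, logp_true)))  # P(all True) = 0.3 * 0.8 * 0.5
print(math.exp(any_logp(True, logp_true)))  # P(any True) = 1 - 0.7 * 0.2 * 0.5
```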
For ordering and min/max, see https://github.com/pymc-devs/pymc/issues/6350
@ricardoV94 I would like to work on this issue, starting with the logical comparison Ops (gt, lt, ge, le), but I need some clarification about the value of the logp for y.
pm.logp(y, 0).eval() # same as pm.logcdf(y, 0.5)
Since y is a bool type, how exactly is pm.logcdf for y defined? And consequently, what should pm.logcdf(y, 0.5) evaluate to?
@shreyas3156 My code comments were wrong. I have updated them.
Anyway, the idea is:
pm.logp(x < 0.5, True) == pm.logcdf(x, 0.5).
Does that make sense? I'm just saying that the probability of a simple inequality against a constant (the constant may be another valued RV) equals some CDF expression of the underlying variable.
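Extending this to all four comparison Ops (lt, le, gt, ge) needs one extra detail for discrete variables: strict and non-strict inequalities differ, because P(x = c) can be nonzero. This is a minimal sketch, assuming an integer-valued x ~ Poisson(mu) and an integer constant c. The helper names are hypothetical. For a continuous x, P(x = c) = 0, so lt/le (and gt/ge) coincide, which is why the Normal example only needs logcdf.

```python
import math

def poisson_logpmf(k: int, mu: float) -> float:
    """log P(x = k) for x ~ Poisson(mu)."""
    return k * math.log(mu) - mu - math.lgamma(k + 1)

def poisson_cdf(k: int, mu: float) -> float:
    """P(x <= k) for x ~ Poisson(mu), by direct summation."""
    if k < 0:
        return 0.0
    return sum(math.exp(poisson_logpmf(i, mu)) for i in range(k + 1))

def comparison_logp_true(op: str, c: int, mu: float) -> float:
    """log P(op(x, c) is True) for x ~ Poisson(mu) and an integer constant c."""
    if op == "le":  # x <= c
        p = poisson_cdf(c, mu)
    elif op == "lt":  # x < c is the same event as x <= c - 1 for integers
        p = poisson_cdf(c - 1, mu)
    elif op == "gt":  # x > c is the complement of x <= c
        p = 1.0 - poisson_cdf(c, mu)
    elif op == "ge":  # x >= c is the complement of x <= c - 1
        p = 1.0 - poisson_cdf(c - 1, mu)
    else:
        raise ValueError(f"unknown comparison: {op}")
    return math.log(p)

for op in ("lt", "le", "gt", "ge"):
    print(op, math.exp(comparison_logp_true(op, 2, 3.0)))
```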
Yes, that makes perfect sense now. Thanks!
The next interesting binary Ops would be and, or, all, and any.