
Should branching logps accept constants?

Open ricardoV94 opened this issue 9 months ago • 0 comments

Description

The following example illustrates a restriction in the current logp derivations when a branch includes a constant:

import pytensor.tensor as pt
import pymc as pm

t = pt.arange(10)
cat = pm.Categorical.dist(p=[0.5, 0.5], shape=(10,))
# cat_fixed = pt.where(t > 5, cat, -1)  # Not accepted because -1 is not measurable
cat_fixed = pt.where(t > 5, cat, pm.DiracDelta.dist(-1, shape=cat.shape))  # fine
pm.logp(cat_fixed, cat_fixed.type())
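The sketch below makes the question concrete. It assumes that an accepted constant would be scored as a point mass, which is what the DiracDelta workaround does: logp 0 where the value equals the constant and -inf elsewhere. It is plain NumPy, not PyMC's logprob machinery. The helpers `categorical_logp`, `point_mass_logp` and `where_logp` are hypothetical names used only for illustration.

```python
import numpy as np


def categorical_logp(value, p):
    """Elementwise log-probability of Categorical(p) at integer `value`.

    Values outside {0, ..., len(p) - 1} get -inf.
    """
    p = np.asarray(p, dtype=float)
    value = np.asarray(value)
    in_support = (value >= 0) & (value < len(p))
    # Clip so indexing is always valid; out-of-support entries are masked below.
    idx = np.clip(value, 0, len(p) - 1)
    return np.where(in_support, np.log(p[idx]), -np.inf)


def point_mass_logp(value, constant):
    """Hypothetical DiracDelta-like logp: 0 at `constant`, -inf elsewhere."""
    return np.where(np.asarray(value) == constant, 0.0, -np.inf)


def where_logp(cond, value, p, constant):
    """Hypothetical logp of where(cond, Categorical(p), constant), elementwise.

    Each entry is scored by the branch that `cond` selects, so the constant
    branch behaves exactly like a point mass (the DiracDelta workaround).
    """
    return np.where(cond, categorical_logp(value, p), point_mass_logp(value, constant))


t = np.arange(10)
cond = t > 5
# A valid value: -1 wherever cond is False, a category wherever it is True.
value = np.where(cond, np.array([0, 1, 0, 1, 1, 0, 1, 0, 0, 1]), -1)
logp = where_logp(cond, value, p=[0.5, 0.5], constant=-1)
```

Here the first six entries of `logp` are 0 and the last four are log(0.5). Changing any of the first six entries of `value` away from -1 makes that entry -inf. This is why, under this assumption, allowing a bare constant would amount to inserting the point mass automatically.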

Should we allow it? The same question applies to operations such as join and make_vector, which may combine measurable and constant inputs.
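For join, the analogous behaviour can be sketched the same way. The joined value is split at the boundary between inputs, the measurable part is scored by its distribution, and the constant part is scored as a point mass. make_vector is the same idea with scalar inputs. This is a NumPy sketch under the point-mass assumption, not PyMC's implementation, and `join_logp` is a hypothetical helper.

```python
import numpy as np


def categorical_logp(value, p):
    """Elementwise log-probability of Categorical(p); -inf outside the support."""
    p = np.asarray(p, dtype=float)
    value = np.asarray(value)
    in_support = (value >= 0) & (value < len(p))
    idx = np.clip(value, 0, len(p) - 1)
    return np.where(in_support, np.log(p[idx]), -np.inf)


def join_logp(value, n_measurable, p, constant):
    """Hypothetical logp of concatenate([Categorical(p, shape=n_measurable), constant]).

    The first `n_measurable` entries are scored by the Categorical; the
    remaining entries must equal `constant` (logp 0) or are impossible (-inf).
    """
    value = np.asarray(value)
    head, tail = value[:n_measurable], value[n_measurable:]
    tail_logp = np.where(tail == np.asarray(constant), 0.0, -np.inf)
    return np.concatenate([categorical_logp(head, p), tail_logp])


logp = join_logp([1, 0, 1, -1, -1], n_measurable=3, p=[0.2, 0.8], constant=[-1, -1])
```

The result is `[log 0.8, log 0.2, log 0.8, 0, 0]`. Any mismatch in the constant segment yields -inf for that entry.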

If we allow it, should we also allow broadcasting? Broadcasting is currently not allowed (hence the need for shape=cat.shape above), because the logp of broadcasted operations can be tricky to handle systematically. For constants, though, it may be fine?
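One way to see the asymmetry is the small NumPy illustration below. It reflects our reading of the concern rather than anything in PyMC. Broadcasting a random scalar creates copies of a single draw, so summing a per-entry logp over the copies overcounts it. A constant scored as a point mass contributes 0 per copy, so broadcasting it does not change the total.

```python
import numpy as np

p = np.array([0.5, 0.5])
shape = (3,)

# Broadcasting a *random* scalar Categorical(p) to shape (3,): three identical
# entries that carry the information of a single draw.
x_value = 1
broadcast_value = np.broadcast_to(x_value, shape)
correct_logp = np.log(p[x_value])              # the draw is counted once
naive_logp = np.log(p[broadcast_value]).sum()  # the same draw counted three times
# (A correct rule would also need to reject values whose copies differ,
# e.g. [1, 0, 1], which a naive elementwise logp would not do.)

# Broadcasting a *constant*, scored as a point mass: each copy contributes
# logp 0 at the constant, so the sum is unaffected by the broadcast.
constant = -1
const_value = np.broadcast_to(constant, shape)
const_logp = np.where(const_value == constant, 0.0, -np.inf).sum()
```

Here `naive_logp` is three times `correct_logp`, while `const_logp` stays 0.0. This suggests that broadcasting constants may not hit the same difficulty.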

ricardoV94, Mar 04 '25 11:03