aeppl
Transforms fail in FAST_COMPILE
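import aeppl  # needed for aeppl.joint_logprob below
import aesara
import aesara.tensor as at
from aeppl.transforms import TransformValuesRewrite, LogTransform

x_rv = at.random.exponential()
# Ask AePPL to compute the log-density of x_rv on the log scale
opt = TransformValuesRewrite({x_rv: LogTransform()})
logp, (x_vv,) = aeppl.joint_logprob(x_rv, extra_rewrites=opt)

# Compiling in FAST_COMPILE mode leaves a TransformedVariable node in the graph
logp_fn = aesara.function([x_vv], logp, mode="FAST_COMPILE")
aesara.dprint(logp_fn)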
Elemwise{add,no_inplace} [id A] 6
|Check{mu > 0} [id B] 5
| |Elemwise{switch,no_inplace} [id C] 4
| | |Elemwise{ge,no_inplace} [id D] 3
| | | |TransformedVariable [id E] 1
| | | | |Elemwise{exp,no_inplace} [id F] 0
| | | | | |<TensorType(float64, ())> [id G]
| | | | |<TensorType(float64, ())> [id G]
| | | |TensorConstant{0.0} [id H]
| | |Elemwise{mul,no_inplace} [id I] 2
| | | |TensorConstant{-1.0} [id J]
| | | |TransformedVariable [id E] 1
| | |TensorConstant{-inf} [id K]
| |TensorConstant{True} [id L]
|<TensorType(float64, ())> [id G]
A TransformedVariable node is still present in the compiled graph, but it shouldn't be.
I'm guessing that our canonicalizations still aren't stable between FAST_COMPILE and FAST_RUN, and/or that some rewrites AePPL assumes are canonicalizations actually aren't. We can always address this by adding those rewrites to AePPL's logprob_rewrites_db, but we should also consider whether the relevant rewrites should be canonicalizations in Aesara.
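To check whether the problem shows up in a given compilation mode, one option is to count the TransformedVariable nodes left in each compiled function's graph. The sketch below is a hypothetical helper: leftover_ops and count_by_mode are not part of AePPL or Aesara. It assumes the compiled graph is reachable through fn.maker.fgraph.toposort(), and it matches Ops by class name instead of importing AePPL's TransformedVariable class.

```python
from typing import Dict, Iterable, List, Sequence


def leftover_ops(nodes: Iterable, op_name: str = "TransformedVariable") -> List:
    """Return the Apply nodes whose Op class is named ``op_name``.

    Only each node's ``.op`` attribute is inspected, so ``nodes`` can be,
    e.g., ``fn.maker.fgraph.toposort()`` for a compiled Aesara function.
    """
    return [node for node in nodes if type(node.op).__name__ == op_name]


def count_by_mode(
    inputs: Sequence,
    outputs,
    modes: Sequence[str] = ("FAST_COMPILE", "FAST_RUN"),
    op_name: str = "TransformedVariable",
) -> Dict[str, int]:
    """Compile the same graph once per mode and count leftover ``op_name`` nodes."""
    # Imported lazily so that leftover_ops can be used without Aesara installed.
    import aesara

    counts = {}
    for mode in modes:
        fn = aesara.function(inputs, outputs, mode=mode)
        # Scan the compiled (rewritten) graph, not the user-facing graph
        counts[mode] = len(leftover_ops(fn.maker.fgraph.toposort(), op_name))
    return counts
```

With the reproduction above, count_by_mode([x_vv], logp) would be expected to report a nonzero count for FAST_COMPILE while the bug is present; once the relevant rewrites run in every mode, all counts should be zero.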