Add rewrite for `log(gamma) -> gammaln`

Open ricardoV94 opened this issue 10 months ago • 0 comments

Description

We're missing this simple rewrite:

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.scalar("x")
out = pt.log(pt.gamma(x))
new_out = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
new_out.dprint()

This can be done easily with a PatternNodeRewriter, as in

https://github.com/pymc-devs/pytensor/blob/911c6a33c2bea6bf1d5b628154e84c43cbed1c63/pytensor/tensor/rewriting/math.py#L3646-L3651

We could also add rewrites for common combinatorics expressions like

naive_betaln = pt.log((pt.gamma(x) * pt.gamma(y)) / pt.gamma(x + y))
betaln = pt.gammaln(x) + pt.gammaln(y) - pt.gammaln(x + y)

https://github.com/pymc-devs/pytensor/blob/ad55b69f3d13f11c6a9a57823c2a88a966db8b1a/pytensor/tensor/special.py#L799-L804

Or for log(poch): https://github.com/pymc-devs/pytensor/blob/ad55b69f3d13f11c6a9a57823c2a88a966db8b1a/pytensor/tensor/special.py#L767-L772

For these more general cases we can probably use something more flexible than the PatternNodeRewriter. We want to apply the rewrite whenever we know that all the terms inside the log are factorials/gammas/exps (positive things that easily blow up). This is a narrower, easier subset of https://github.com/pymc-devs/pytensor/discussions/177
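To make the "easily blow up" point concrete, here is a small numerical sketch using only the Python standard library (`math.gamma` and `math.lgamma` stand in for the PyTensor ops; the helper names are made up for the example). The two forms agree for small inputs, but the naive one overflows once the gammas no longer fit in a float64.

```python
import math


def naive_betaln(x, y):
    # Direct translation of log(gamma(x) * gamma(y) / gamma(x + y))
    return math.log(math.gamma(x) * math.gamma(y) / math.gamma(x + y))


def stable_betaln(x, y):
    # Same quantity, computed in log space
    return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)


# Small inputs: both forms give the same answer
small_naive = naive_betaln(3.0, 4.0)
small_stable = stable_betaln(3.0, 4.0)

# Large inputs: gamma(200) does not fit in a float64, so the naive form fails
try:
    naive_betaln(200.0, 200.0)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

large_stable = stable_betaln(200.0, 200.0)  # still finite
```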

ricardoV94 · Jan 31 '25 10:01