C-Cache not shared across functions with same operations but distinct constants
Description
```python
from pytensor.graph import FunctionGraph
import pytensor.scalar as ps
from pytensor.link.c.basic import CLinker

x = ps.float64("x")
o = x + 1
cl = CLinker().accept(FunctionGraph([x], [o]))
cl.cmodule_key()
```
```python
from pytensor.graph import FunctionGraph
import pytensor.tensor as pt
from pytensor.link.c.basic import CLinker

x = pt.vector("x")
y = pt.vector("y")
z = pt.vector("z")
c1 = pt.constant([1, 1, 1, 1, 1])
c2 = pt.constant([1, 1, 1, 1, 2])

# Same source code
assert (
    CLinker().accept(FunctionGraph([x, y], [x + y])).get_src_code()
    == CLinker().accept(FunctionGraph([x, z], [x + z])).get_src_code()
)

# Same source code
assert (
    CLinker().accept(FunctionGraph([x], [x + c1])).get_src_code()
    == CLinker().accept(FunctionGraph([x], [x + c2])).get_src_code()
)

# Same hashing
assert (
    CLinker().accept(FunctionGraph([x, y], [x + y])).cmodule_key()
    == CLinker().accept(FunctionGraph([x, z], [x + z])).cmodule_key()
)

# Distinct hashing, even though the source code is identical
assert (
    CLinker().accept(FunctionGraph([x], [x + c1])).cmodule_key()
    != CLinker().accept(FunctionGraph([x], [x + c2])).cmodule_key()
)
```
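To make the difference concrete, here is a toy model of a module key in plain Python. It is not pytensor's `cmodule_key`: `Var`, `Const`, `current_key` and `constant_agnostic_key` are hypothetical names, and the type signature strings are placeholders. It assumes that a constant's value only affects the compiled module when it is inlined into the C source as a literal, so the proposed key keeps the constant's type and drops its data otherwise.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Var:
    """Toy non-constant input, identified only by a type signature string."""
    type_sig: str


@dataclass(frozen=True)
class Const:
    """Toy constant input: type signature, data, and whether its value is
    inlined into the C source as a literal (as scalar c_literals are)."""
    type_sig: str
    data: Tuple[float, ...]
    inlined: bool = False


def current_key(op_name: str, inputs) -> tuple:
    """Mimics the behaviour described above: constant data is always hashed."""
    parts = [op_name]
    for inp in inputs:
        if isinstance(inp, Const):
            parts.append(("const", inp.type_sig, inp.data))
        else:
            parts.append(("var", inp.type_sig))
    return tuple(parts)


def constant_agnostic_key(op_name: str, inputs) -> tuple:
    """Proposed variant: constant data is hashed only if it reaches the C source."""
    parts = [op_name]
    for inp in inputs:
        if isinstance(inp, Const):
            if inp.inlined:
                parts.append(("literal", inp.type_sig, inp.data))
            else:
                # In this toy model the value is supplied through storage,
                # not the source, so only its type affects the module.
                parts.append(("const", inp.type_sig))
        else:
            parts.append(("var", inp.type_sig))
    return tuple(parts)


x = Var("vector<float64>")
c1 = Const("vector<int>", (1, 1, 1, 1, 1))
c2 = Const("vector<int>", (1, 1, 1, 1, 2))

print(current_key("add", [x, c1]) == current_key("add", [x, c2]))  # False
print(constant_agnostic_key("add", [x, c1]) == constant_agnostic_key("add", [x, c2]))  # True

# Inlined scalar literals still give distinct keys, since they change the source.
s1 = Const("scalar<float64>", (1.0,), inlined=True)
s2 = Const("scalar<float64>", (2.0,), inlined=True)
print(constant_agnostic_key("add", [x, s1]) == constant_agnostic_key("add", [x, s2]))  # False
```

The `inlined` flag stands in for the scalar `c_literal` case; everything else about a constant is reduced to its type.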
This is wasteful. It's common to have Ops with the same code but distinct constant inputs. The only cases where constants matter are:
- Scalar constants that have a `c_literal`, which is used directly in the generated code. It should be easy to add a workaround for this.
- Operations that use node constant info to specialize their C code. This was never done before, but I started doing it for `Join`/`Split` (most of the time `axis` is constant) and for `AdvancedSubtensor1`/`AdvancedIncSubtensor1`, to decide whether we need to check for negative/invalid indices.
This information could (and should) be encoded in `c_code_cache_version_apply`, which has access to the node. If we handle those cases, we can reuse many more compiled functions.
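A sketch of that idea, assuming a hypothetical Op that, like the `AdvancedSubtensor1` case above, only needs negative-index handling when its indices are not known to be non-negative. `ToyTake`, `ToyNode`, `ToyConst`, `ToyVar` and `toy_module_key` are made up for illustration; only the method name `c_code_cache_version_apply` mirrors the hook mentioned above. The version tuple records the decision that changes the generated C code, not the constant values, so graphs whose constants lead to the same decision would share a cache entry.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class ToyVar:
    """A non-constant input: its value is unknown when C code is generated."""
    type_sig: str


@dataclass(frozen=True)
class ToyConst:
    """A constant input whose data is known when C code is generated."""
    type_sig: str
    data: Tuple[int, ...]


@dataclass(frozen=True)
class ToyNode:
    inputs: Tuple[Union[ToyVar, ToyConst], ...]


class ToyTake:
    """Hypothetical Op indexing inputs[0] with inputs[1] (not a pytensor class)."""

    base_version = (1,)

    def needs_negative_index_handling(self, node: ToyNode) -> bool:
        idx = node.inputs[1]
        if isinstance(idx, ToyConst):
            # Constant indices: we know up front whether any of them is negative.
            return any(i < 0 for i in idx.data)
        # Unknown indices: the generated C code must handle negative values.
        return True

    def c_code_cache_version_apply(self, node: ToyNode) -> tuple:
        # Encode the specialization decision, not the constant values.
        return (*self.base_version, self.needs_negative_index_handling(node))


def toy_module_key(op: ToyTake, node: ToyNode) -> tuple:
    """Hypothetical key: Op name, per-node version and input types, no constant data."""
    input_sigs = tuple(
        ("const" if isinstance(inp, ToyConst) else "var", inp.type_sig)
        for inp in node.inputs
    )
    return (type(op).__name__, op.c_code_cache_version_apply(node), input_sigs)


op = ToyTake()
x = ToyVar("vector<float64>")
node_a = ToyNode((x, ToyConst("vector<int64>", (0, 1, 2))))
node_b = ToyNode((x, ToyConst("vector<int64>", (3, 4, 5))))
node_c = ToyNode((x, ToyConst("vector<int64>", (-1, 0, 1))))

print(toy_module_key(op, node_a) == toy_module_key(op, node_b))  # True: same C code
print(toy_module_key(op, node_a) == toy_module_key(op, node_c))  # False: needs negative-index handling
```

The design choice here is that the key only grows with the number of distinct specializations an Op can emit, rather than with the number of distinct constant values it sees.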
Constants are incorporated into the `cmodule_key` here: https://github.com/pymc-devs/pytensor/blob/27c21cd65e347f8245012eb65486c26dc341915c/pytensor/link/c/basic.py#L1395-L1412
Scalar literals are introduced here (only for constants that are not fgraph inputs):
https://github.com/pymc-devs/pytensor/blob/27c21cd65e347f8245012eb65486c26dc341915c/pytensor/link/c/basic.py#L624-L630
https://github.com/pymc-devs/pytensor/blob/27c21cd65e347f8245012eb65486c26dc341915c/pytensor/link/c/basic.py#L693-L695
An Op whose C code exploits constant node info: https://github.com/pymc-devs/pytensor/blob/27c21cd65e347f8245012eb65486c26dc341915c/pytensor/tensor/basic.py#L2331-L2349