
Transform graph to make large constants symbolic inputs

Open ricardoV94 opened this issue 10 months ago • 1 comments

Description

JAX jitting can be insanely slow when there are large constants in the graph. We could add a helper that converts any large constants into symbolic inputs (we already do some constant folding on our end anyway), so JAX doesn't get hung up on them.

See related discussion on their side: https://github.com/jax-ml/jax/issues/21300

The idea is to have a `pytensor.graph.replace.replace_large_constants_by_inputs` that returns the graph with the constants replaced by PyTensor input variables, together with the respective constant values.

ricardoV94 avatar Feb 19 '25 13:02 ricardoV94

Here is one approach I got working:

import pytensor
from pytensor import shared
from pytensor.compile.function.types import Function
from pytensor.compile.mode import Mode
from pytensor.graph import FunctionGraph
from pytensor.tensor.variable import TensorConstant

def recreate_fn_with_large_constants_as_shared(fn: Function, size_threshold: int = 1000):
    # Collect every large constant in the compiled graph and build a
    # shared variable holding the same data
    large_constants_to_shared = {}
    fg = fn.maker.fgraph
    for var in fg.variables:
        if isinstance(var, TensorConstant) and var.data.size > size_threshold:
            large_constants_to_shared[var] = shared(var.data, shape=var.type.shape, name=var.name)

    if not large_constants_to_shared:
        return fn

    # Clone the graph (reusing the original inputs) and swap the constants
    # for the shared variables
    new_fg = FunctionGraph(
        fg.inputs,
        fg.outputs,
        copy_inputs=False,
    )
    new_fg.replace_all(tuple(large_constants_to_shared.items()), import_missing=True)

    # Recompile with the original linker, skipping rewrites since the graph
    # was already optimized
    new_fn = pytensor.function(
        new_fg.inputs,
        new_fg.outputs,
        mode=Mode(linker=fn.maker.linker, optimizer=None),
        accept_inplace=True,
    )

    return new_fn

ricardoV94 avatar Feb 28 '25 12:02 ricardoV94

I guess this issue just needs someone to actually do it? It looks like the code you posted would just work.

jessegrabowski avatar May 30 '25 07:05 jessegrabowski