Transform graph to make large constants symbolic inputs
Description
JAX jitting can be insanely slow when there are large constants in the graph. We could add a helper that converts any large constants to symbolic inputs (we already did some constant-folding work on our end anyway), so JAX doesn't get hung up on them.
See related discussion on their side: https://github.com/jax-ml/jax/issues/21300
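For a sense of what triggers it, something like the sketch below embeds a large constant directly in the graph (sizes and names are made up for illustration), which JAX then has to handle as a literal when jitting:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# Large array baked into the graph as a TensorConstant
big = pt.constant(np.random.normal(size=100_000), name="big")
# JAX has to process the constant at jit/trace time, which is where it chokes
fn = pytensor.function([x], x + big, mode="JAX")
```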
The idea is to have a `pytensor.graph.replace.replace_large_constants_by_inputs` helper that returns the graph with the large constants replaced by PyTensor input variables, together with the respective constant values to feed in for them.
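For illustration, a rough sketch of what that helper could look like, built on the existing `graph_replace` utility (the function name and return convention are assumptions taken from the description above, not an existing API, and this assumes `graph_replace` can substitute constants):

```python
from pytensor.graph.basic import ancestors
from pytensor.graph.replace import graph_replace
from pytensor.tensor.variable import TensorConstant


def replace_large_constants_by_inputs(outputs, size_threshold: int = 1000):
    # Walk the graph and collect every constant above the size threshold
    replacements = {}
    for var in ancestors(outputs):
        if isinstance(var, TensorConstant) and var.data.size > size_threshold:
            # Create a fresh symbolic input of the same type as the constant
            replacements[var] = var.type(name=var.name)
    if not replacements:
        return outputs, [], []
    # Rebuild the outputs with the constants swapped for the new inputs
    new_outputs = graph_replace(outputs, replacements)
    new_inputs = list(replacements.values())
    values = [const.data for const in replacements]
    return new_outputs, new_inputs, values
```

The caller would then compile with the extra inputs and pass `values` for them at call time.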
Here is one approach I got working:
```python
import pytensor
from pytensor import shared
from pytensor.compile.function.types import Function
from pytensor.compile.mode import Mode
from pytensor.graph import FunctionGraph
from pytensor.tensor.variable import TensorConstant


def recreate_fn_with_large_constants_as_shared(fn: Function, size_threshold: int = 1000):
    # Collect every large constant in the compiled graph and build a
    # shared variable holding the same data
    large_constants_to_shared = {}
    fg = fn.maker.fgraph
    for var in fg.variables:
        if isinstance(var, TensorConstant) and var.data.size > size_threshold:
            large_constants_to_shared[var] = shared(
                var.data, shape=var.type.shape, name=var.name
            )
    if not large_constants_to_shared:
        return fn

    # Rebuild the graph and swap the constants for the shared variables
    new_fg = FunctionGraph(
        fg.inputs,
        fg.outputs,
        copy_inputs=False,
    )
    new_fg.replace_all(tuple(large_constants_to_shared.items()), import_missing=True)

    # Recompile with the same linker but no further rewrites, since the
    # graph was already optimized during the original compilation
    new_fn = pytensor.function(
        fg.inputs,
        new_fg.outputs,
        mode=Mode(linker=fn.maker.linker, optimizer=None),
        accept_inplace=True,
    )
    return new_fn
```
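For reference, the intended call pattern would be something like this (sizes made up):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
big = pt.constant(np.random.normal(size=50_000), name="big")

fn = pytensor.function([x], x + big, mode="JAX")
# Recreate the function so the big constant lives in a shared variable
# instead of being traced as a literal by JAX
fast_fn = recreate_fn_with_large_constants_as_shared(fn, size_threshold=1000)
fast_fn(np.zeros(50_000))
```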
I guess this issue just needs someone to actually do it? It looks like the code you posted would just work.