Specialize away Split
Split shows up in the graph of Join. It is really just a fancy sequence of slice operations; np.split itself is a thin wrapper that does exactly that:
https://github.com/numpy/numpy/blob/e7a123b2d3eca9897843791dd698c1803d9a39c2/numpy/lib/_shape_base_impl.py#L789-L796
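Roughly, the linked lines boil down to something like this (a paraphrase for illustration, omitting the validation of section sizes):

```python
import numpy as np

def split_sketch(ary, div_points, axis=0):
    # Each output is just a basic slice of the input at consecutive
    # division points, i.e. a view rather than a copy.
    sary = np.swapaxes(ary, axis, 0)
    return [
        np.swapaxes(sary[start:end], axis, 0)
        for start, end in zip(div_points[:-1], div_points[1:])
    ]

# split_sketch(np.arange(10), [0, 4, 6, 10]) matches np.split(np.arange(10), [4, 6])
```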
If it were not for the possible dynamic axis (which Join also supports), we wouldn't need Split at all.
We may still want it for second-order derivatives of Join graphs. The gradient through Split is cleaner than the eager gradient through multiple Subtensors: it is just the reverse Join of the output gradients.
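For example, something like this (a small illustrative snippet, not new functionality):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(10,))
a, b = pt.split(x, [4, 6], 2)
# The gradient of Split w.r.t. x is simply the Join of the two output
# gradients along the split axis.
g = pytensor.grad((2 * a).sum() + (3 * b).sum(), x)
pytensor.dprint(g)
```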
I added a specialization rewrite that converts Split into the equivalent Subtensor graph. I see speedups in all backends.
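For intuition, the rewrite exploits an equivalence like the following (hand-written here for illustration; this is not the rewrite's actual code):

```python
import numpy as np
import pytensor.tensor as pt

x = pt.vector("x", shape=(1000,), dtype=int)
a, b, c = pt.split(x, (250, 250, 500), 3)
# Equivalent Subtensor (slice) graph when the sizes and axis are static
a2, b2, c2 = x[:250], x[250:500], x[500:]

x_val = np.arange(1000)
assert np.array_equal(a.eval({x: x_val}), a2.eval({x: x_val}))
```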
The C backend wasn't really capable of returning views of the input, so this optimization avoids as many copies as there are splits:
```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

x = pt.vector("x", shape=(1000,), dtype=int)
ys = pt.split(x, (250, 250, 500), 3)
fn = pytensor.function(
    # Avoid deepcopies
    [pytensor.In(x, borrow=True)],
    [pytensor.Out(y, borrow=True) for y in ys],
    mode=get_mode("NUMBA"),  # .excluding("split_to_subtensor"),
    trust_input=True,
)
# print_view_map=True shows which outputs are views of the input
fn.dprint(print_view_map=True)

x_test = np.arange(1000)
fn(x_test)
%timeit fn(x_test)
```
📚 Documentation preview 📚: https://pytensor--1334.org.readthedocs.build/en/1334/