`scan_save_mem` rewrite masks invalid indexing when `n_steps=0`
Description
```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

n = pt.iscalar("n")
x0 = pt.vector("x0")
xs, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n)
out = xs[-1]  # Invalid when n_steps=0, because xs is empty

# Default mode: no error is raised
fn = pytensor.function([n, x0], out)
print(fn(n=0, x0=[0, 1]))  # [1. 2.]

# Excluding the shape_unsafe rewrites: still no error
fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("shape_unsafe"))
print(fn(n=0, x0=[0, 1]))  # [1. 2.]

# Excluding scan_save_mem: the invalid index is reported
fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("scan_save_mem"))
print(fn(n=0, x0=[0, 1]))  # IndexError: index out of bounds
```
I suspect this comes from the following hack in the rewrite: https://github.com/pymc-devs/pytensor/blob/8454c3ba78b1889987a909745f87c8dcab8e48fb/pytensor/scan/rewriting.py#L1438-L1445
However, removing the hack makes some tests fail, so other code downstream of it (or perhaps inside the rewrite itself) may be relying on wrong assumptions.
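For comparison, here is a minimal NumPy sketch of the semantics I would expect from the snippet above. It is not PyTensor code: `reference_scan` and `last_step` are hypothetical helpers written only for this illustration, and the sketch assumes the scan output holds one row per step and does not include the initial state.

```python
import numpy as np


def reference_scan(step, x0, n_steps):
    """Plain NumPy stand-in for the scan in this report (hypothetical helper)."""
    x0 = np.asarray(x0, dtype="float64")
    outputs = []
    prev = x0
    for _ in range(n_steps):
        prev = step(prev)
        outputs.append(prev)
    # One row per step; with n_steps=0 this is an empty array of shape (0, *x0.shape)
    return np.array(outputs, dtype="float64").reshape((n_steps, *x0.shape))


def last_step(x0, n_steps):
    """Mirror `xs[-1]` from the report on the NumPy reference output."""
    xs = reference_scan(lambda xtm1: xtm1 + 1, x0, n_steps)
    return xs[-1]  # Raises IndexError when xs has no rows


print(last_step([0, 1], 1))  # [1. 2.]
try:
    last_step([0, 1], 0)
except IndexError as err:
    print("IndexError:", err)
```

Under these assumptions, `n_steps=0` gives an empty sequence and `xs[-1]` fails, which matches the run with `scan_save_mem` excluded. The `[1. 2.]` returned by the other two runs is the value this reference produces for `n_steps=1`.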