
Scan save mem rewrite masks issues with `steps=0`

Open · ricardoV94 opened this issue 9 months ago · 0 comments

Description

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

n = pt.iscalar("n")
x0 = pt.vector("x0")
xs, _ = pytensor.scan(lambda xtm1: xtm1 + 1, outputs_info=[x0], n_steps=n)

out = xs[-1]  # Invalid when n_steps=0: xs has no entries

fn = pytensor.function([n, x0], out)
print(fn(n=0, x0=[0, 1]))  # [1. 2.]  (no error raised)

fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("shape_unsafe"))
print(fn(n=0, x0=[0, 1]))  # [1. 2.]  (still no error raised)

fn = pytensor.function([n, x0], out, mode=get_default_mode().excluding("scan_save_mem"))
print(fn(n=0, x0=[0, 1]))  # IndexError: index out of bounds
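For reference, here is a minimal eager sketch of what the graph above should compute, written in plain NumPy rather than PyTensor. `reference_scan` is a hypothetical helper, not part of PyTensor. It assumes the usual `scan` convention for a single recurrent output with `outputs_info=[x0]`: the returned sequence holds one state per step and does not include `x0`. Under that convention `n_steps=0` yields an empty `(0, d)` array, so `xs[-1]` must fail, which matches the behavior seen only when `scan_save_mem` is excluded.

```python
import numpy as np


def reference_scan(step, x0, n_steps):
    """Hypothetical eager reference for a single-output recurrence.

    Returns an array of shape (n_steps, *x0.shape) with the state after each
    step. The initial state x0 is not included in the result.
    """
    x = np.asarray(x0, dtype=float)
    history = []
    for _ in range(n_steps):
        x = step(x)
        history.append(x)
    # With n_steps=0 nothing is collected, so this becomes an empty (0, d) array
    return np.array(history, dtype=float).reshape((n_steps,) + x.shape)


xs = reference_scan(lambda xtm1: xtm1 + 1, [0, 1], n_steps=3)
print(xs[-1])  # [3. 4.]

empty = reference_scan(lambda xtm1: xtm1 + 1, [0, 1], n_steps=0)
print(empty.shape)  # (0, 2)
try:
    empty[-1]  # the equivalent of `out = xs[-1]` with zero steps
except IndexError as err:
    print("IndexError:", err)
```

With zero steps there is no "last" element, so an out-of-bounds error is the expected result rather than `[1. 2.]`.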

I suspect this comes from this hack: https://github.com/pymc-devs/pytensor/blob/8454c3ba78b1889987a909745f87c8dcab8e48fb/pytensor/scan/rewriting.py#L1438-L1445

However, removing this hack makes some tests fail, so other code downstream of it (or perhaps inside the rewrite itself) may be making incorrect assumptions.
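The wrong value is informative: `[1. 2.]` is exactly `x0 + 1`, meaning one step's worth of work, even though `n=0`. One possible explanation, not verified against the linked code, is that the memory-saving path effectively runs at least one iteration and keeps only the most recent state. The sketch below is purely hypothetical (`clamped_last_step` and the `max(n_steps, 1)` clamp are assumptions, not PyTensor's implementation). It only shows how such a clamp turns an out-of-bounds `xs[-1]` into a silently wrong result.

```python
import numpy as np


def clamped_last_step(step, x0, n_steps):
    """HYPOTHETICAL: mimic a rewrite that keeps only the latest state in a
    length-1 buffer and runs max(n_steps, 1) iterations.

    This is not PyTensor's code; it illustrates how clamping the step count
    hides the n_steps=0 case instead of raising an IndexError.
    """
    x = np.asarray(x0, dtype=float)
    buffer = np.empty((1,) + x.shape)  # only the most recent state is stored
    for _ in range(max(n_steps, 1)):  # assumed clamp: always at least one step
        x = step(x)
        buffer[0] = x
    return buffer[-1]  # always valid, even when n_steps=0


print(clamped_last_step(lambda xtm1: xtm1 + 1, [0, 1], n_steps=0))  # [1. 2.]
print(clamped_last_step(lambda xtm1: xtm1 + 1, [0, 1], n_steps=3))  # [3. 4.]
```

For `n_steps >= 1` the clamp is harmless, which would explain why only the zero-step case is affected.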

ricardoV94, Mar 12 '25 09:03