
Replace 0-sized nodes with zeros

Open ricardoV94 opened this issue 9 months ago • 0 comments

Description

If any operation produces an empty (zero-size) output, we can truncate the whole graph above it:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(10,))
out = pt.add.outer(x, x)[9:-1]

fn = pytensor.function([x], out)
fn.dprint(print_shape=True)
# Subtensor{start:stop} [id A] shape=(0, 10) 3
#  ├─ Add [id B] shape=(10, 10) 2
#  │  ├─ ExpandDims{axis=1} [id C] shape=(10, 1) 1
#  │  │  └─ x [id D] shape=(10,)
#  │  └─ ExpandDims{axis=0} [id E] shape=(1, 10) 0
#  │     └─ x [id D] shape=(10,)
#  ├─ 9 [id F] shape=()
#  └─ -1 [id G] shape=()
```

PyTensor knows that the shape is (0, 10), so we can replace the whole graph with zeros. The worst this could do is hide shape errors in the discarded part of the graph, so we can tag the rewrite as shape_unsafe.
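To make the idea concrete, here is a toy model in plain Python rather than PyTensor. Node, count_nodes and replace_empty_by_zeros are hypothetical names, not PyTensor API. The graph mirrors the dprint output above, and the sketch assumes the static shape is fully known before it swaps the output for an input-less Zeros node.

```python
from dataclasses import dataclass, field
from math import prod


@dataclass(eq=False)
class Node:
    """Hypothetical toy node: an op name, its inputs and its static shape."""

    op: str
    inputs: list = field(default_factory=list)
    shape: tuple = ()  # None marks a dimension that is not known statically


def count_nodes(out):
    """Count the distinct nodes in the graph that computes `out`."""
    seen, stack = set(), [out]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen.add(id(node))
        stack.extend(node.inputs)
    return len(seen)


def replace_empty_by_zeros(node):
    """Return an input-less Zeros node if `node` is statically known to be empty.

    Shape-unsafe: any shape error in the dropped subgraph is no longer raised.
    """
    if all(dim is not None for dim in node.shape) and prod(node.shape) == 0:
        return Node("Zeros", [], node.shape)
    return node


# Toy mirror of the graph from the issue: add.outer(x, x)[9:-1]
x = Node("x", [], (10,))
col = Node("ExpandDims{axis=1}", [x], (10, 1))
row = Node("ExpandDims{axis=0}", [x], (1, 10))
add = Node("Add", [col, row], (10, 10))
out = Node("Subtensor{start:stop}", [add, Node("9"), Node("-1")], (0, 10))

new_out = replace_empty_by_zeros(out)
print(count_nodes(out), "->", count_nodes(new_out))  # 7 -> 1
print(new_out.op, new_out.shape)  # Zeros (0, 10)
```

Because Zeros has no inputs, everything above the old output becomes dead code. That is also why the rewrite is shape-unsafe: if, say, the two inputs to Add had incompatible shapes, the error would no longer be raised.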

We need to be careful with operations whose output shape cannot be determined without computing them (such as a Scan with a while condition). We can ask the ShapeFeature for the shape of the variable: if the variable itself is still in its shape graph, we can't avoid computing the graph. Otherwise, we can just replace it with zeros.
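A toy sketch of that check, again in plain Python rather than PyTensor. Var, in_graph, can_replace_by_zeros and the shape_of dict are hypothetical stand-ins; shape_of plays the role of the per-variable shape graphs kept by the ShapeFeature. The sketch assumes, as the issue suggests, that a while-Scan's shape graph refers back to its own output.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Var:
    """Hypothetical toy variable: a name and the variables it is computed from."""

    name: str
    inputs: list = field(default_factory=list)


def in_graph(target, roots):
    """True if `target` is one of `roots` or any of their ancestors."""
    seen, stack = set(), list(roots)
    while stack:
        var = stack.pop()
        if var is target:
            return True
        if id(var) not in seen:
            seen.add(id(var))
            stack.extend(var.inputs)
    return False


def can_replace_by_zeros(var, shape_of):
    """`shape_of` stands in for the ShapeFeature: var -> list of symbolic dims.

    If `var` still appears in its own shape graph, its shape can only be
    known by computing `var`, so truncating the graph saves nothing.
    """
    return not in_graph(var, shape_of[var])


n_steps = Var("n_steps")
seq = Var("seq", [n_steps])
# for-Scan: the output length comes from an input (n_steps)
for_scan = Var("for_scan", [seq, n_steps])
# while-Scan: the length is only known after running the loop, so the
# shape graph points back at the output itself (like Shape_i(while_scan))
while_scan = Var("while_scan", [seq])
shape_of = {
    for_scan: [n_steps],
    while_scan: [Var("Shape_i{0}", [while_scan])],
}

print(can_replace_by_zeros(for_scan, shape_of))    # True
print(can_replace_by_zeros(while_scan, shape_of))  # False
```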

ricardoV94, Mar 17 '25 17:03