
Rewrite nested inc/set_subtensor on zeros

Open ricardoV94 opened this issue 9 months ago • 0 comments

Description

The gradient of x[1:][-1] contains two successive inc_subtensor operations on zero buffers of increasing size. We should collapse them into one, which is what already happens if you take the gradient of x[-1], the single index equivalent to the two chained ones.

This shows up in the gradient of Scans when taking the last outputs of a recurrent sequence.
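The equivalence behind this rewrite can be checked without PyTensor. Below is a minimal plain-Python model, not PyTensor code: the hypothetical helper `inc_on_zeros(n, key, y)` stands in for `inc_subtensor(zeros(n)[key], y)`, and the nested construction from the gradient of x[1:][-1] is compared with the single one from x[-1].

```python
def inc_on_zeros(n, key, y):
    """Hypothetical model of inc_subtensor(zeros(n)[key], y) on plain lists."""
    buf = [0.0] * n
    positions = range(n)[key]  # range follows Python's int and slice indexing rules
    if isinstance(positions, int):
        buf[positions] += y
    else:
        for pos, val in zip(positions, y):
            buf[pos] += val
    return buf

# Gradient of x[1:][-1]: inc at -1 into zeros(3), then inc at [1:] into zeros(4).
nested = inc_on_zeros(4, slice(1, None), inc_on_zeros(3, -1, 1.0))
# Gradient of x[-1]: a single inc at -1 into zeros(4).
single = inc_on_zeros(4, -1, 1.0)

assert nested == single  # both are [0.0, 0.0, 0.0, 1.0]
```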

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.vector("x", shape=(4,))
out = x[1:][-1]  # When you select the last entry of a scan sitsot this shows up in the graph
g = pt.grad(out, x)
rewrite_graph(g, include=("fast_run",), exclude=("inplace",)).dprint()
# IncSubtensor{start:} [id A]
#  ├─ Alloc [id B]
#  │  ├─ [0.] [id C]
#  │  └─ 4 [id D]
#  ├─ IncSubtensor{i} [id E]
#  │  ├─ Alloc [id F]
#  │  │  ├─ [0.] [id C]
#  │  │  └─ 3 [id G]
#  │  ├─ 1.0 [id H]
#  │  └─ -1 [id I]
#  └─ 1 [id J]

new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
new_g = pt.grad(new_out, x)
rewrite_graph(new_g, include=("fast_run",), exclude=("inplace",)).dprint()
# IncSubtensor{i} [id A]
#  ├─ Alloc [id B]
#  │  ├─ [0.] [id C]
#  │  └─ 4 [id D]
#  ├─ 1.0 [id E]
#  └─ 3 [id F]

I think the rule is: do a single inc_subtensor on the larger buffer, keeping the inner index as-is when it is negative, or using outer start + inner index when it is non-negative. We may also want to handle an index of unknown sign symbolically, but even the constant case would be a nice start.
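A sketch of the constant case, assuming a non-negative constant outer start, an outer slice with no stop and unit step, and an in-bounds inner index (the situation in the x[1:][-1] example). `collapse_index` is a hypothetical helper, not an existing PyTensor function; the loop checks it against Python's own indexing rules. The negative-index branch relies on the outer slice running to the end of the buffer, so an outer stop would need extra handling.

```python
def collapse_index(outer_start: int, inner_index: int) -> int:
    """Hypothetical helper: index into the full buffer equivalent to
    buffer[outer_start:][inner_index].

    Assumes outer_start >= 0, the outer slice has no stop and unit step,
    and inner_index is in bounds for the sliced buffer.
    """
    if inner_index < 0:
        # Negative indices count from the end, which both buffers share.
        return inner_index
    # Non-negative indices are shifted by where the outer slice starts.
    return outer_start + inner_index

# Brute-force check against Python's indexing semantics.
for n in range(1, 8):
    full = list(range(n))
    for start in range(n):
        sub_len = n - start
        for inner in range(-sub_len, sub_len):
            assert full[start:][inner] == full[collapse_index(start, inner)]
```

For the issue's example, `collapse_index(1, -1)` gives `-1`, i.e. the single x[-1] index.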

Bonus points if we can combine it with an outer flip that the scan gradient also does:

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.vector("x", shape=(4,))
out = x[1:][-1]
new_out = rewrite_graph(out, include=("ShapeOpt", "canonicalize"))
new_g = pt.grad(new_out, x)[::-1]
rewrite_graph(new_g, include=("fast_run",), exclude=("inplace",)).dprint()
# Subtensor{::step} [id A]
#  ├─ IncSubtensor{i} [id B]
#  │  ├─ Alloc [id C]
#  │  │  ├─ [0.] [id D]
#  │  │  └─ 4 [id E]
#  │  ├─ 1.0 [id F]
#  │  └─ 3 [id G]
#  └─ -1 [id H]

This should be doable by flipping the index. It's less important, though, since the flip is just a cheap view of the input.
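For the flip, a one-hot buffer reversed with [::-1] equals a one-hot buffer with the entry moved to the mirrored position: n - 1 - i for a non-negative index, or -1 - i for a negative one. A minimal sketch with a hypothetical `flip_index` helper (not a PyTensor function), checked on plain lists:

```python
def flip_index(i: int, n: int) -> int:
    """Hypothetical helper: position of buffer[i] after buffer[::-1], for in-bounds i."""
    # Non-negative i maps to n - 1 - i; negative i (position n + i) maps to -1 - i.
    return n - 1 - i if i >= 0 else -1 - i

def one_hot(n: int, i: int) -> list:
    """Zeros of length n with 1.0 at index i, like IncSubtensor on an Alloc of zeros."""
    buf = [0.0] * n
    buf[i] = 1.0
    return buf

# Reversing the buffer is the same as writing at the mirrored index.
for n in range(1, 8):
    for i in range(-n, n):
        assert one_hot(n, i)[::-1] == one_hot(n, flip_index(i, n))

# The issue's case: IncSubtensor at index 3 of a length-4 buffer, then flipped.
assert flip_index(3, 4) == 0
```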

ricardoV94, Mar 12 '25 23:03