
Second gradient of a simple Scan shows some missing simplifications

Open · ricardoV94 opened this issue 2 months ago • 3 comments

Description

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import Mode

x0 = pt.scalar("x0")
ys, _ = pytensor.scan(
    lambda ytm1: ytm1 ** 2,
    outputs_info=[x0],
    n_steps=4,
    mode=Mode(linker="py", optimizer="fast_run").excluding("fusion"),
)
f = ys[-1]
g = pt.grad(f, x0)
h = pt.grad(g, x0)
mode = Mode(linker="py", optimizer="fast_run").excluding("scan_pushout")
fn = pytensor.function([x0], h, mode=mode)

# Issues: 
#  - Unused outer_in_seqs-1 in Scan{grad_of_grad_of_scan_fn} isn't removed
#      - Probably only becomes useless after inner graph is rewritten
#  - Repeated 2 * inner_in_mit_mot-0-0. There is also an Add of it to itself, which is equivalent. Problem remains with fusion
#  - Nested IncSubtensor (sometimes separated by a reverse slice). This can probably be cleaned up quite a lot
#  - Unnecessary ExpandDims on the value written by SetSubtensor
#  - Cryptic [3:-6:-1] slice, equivalent to [3:None:-1]
#  - Useless alloc(0, 4)[:4].inc(...)
#  - Useless sum on a length-1 tensor at the end of the graph
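# Quick illustrative checks of two of the points above (x_test and buf are just
# throwaway names): the scan computes x0 ** 16, so h should equal 240 * x0 ** 14,
# and for a length-5 buffer stop=-6 clips past the start, which is why
# [3:-6:-1] behaves like [3:None:-1].
import numpy as np

x_test = 1.5
np.testing.assert_allclose(fn(x_test), 240 * x_test ** 14)

buf = np.arange(5)
assert np.array_equal(buf[3:-6:-1], buf[3:None:-1])  # both give [3, 2, 1, 0]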

fn.dprint(print_shape=True, print_op_info=True)
dprint

Sum{axes=None} [id A] shape=() 23
 └─ Subtensor{start:stop:step} [id B] shape=(?,) 22
    ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C] shape=(?,) 21 (outer_out_mit_mot-0)
    │  ├─ 4 [id D] shape=() (n_steps)
    │  ├─ Subtensor{start:stop:step} [id E] shape=(?,) 10 (outer_in_seqs-0)
    │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
    │  │  │  ├─ 4 [id G] shape=() (n_steps)
    │  │  │  └─ SetSubtensor{:stop} [id H] shape=(5,) 6 (outer_in_sit_sot-0)
    │  │  │     ├─ AllocEmpty{dtype='float64'} [id I] shape=(5,) 2
    │  │  │     │  └─ 5 [id J] shape=()
    │  │  │     ├─ ExpandDims{axis=0} [id K] shape=(1,) 3
    │  │  │     │  └─ x0 [id L] shape=()
    │  │  │     └─ 1 [id M] shape=()
    │  │  ├─ 3 [id N] shape=()
    │  │  ├─ -6 [id O] shape=()
    │  │  └─ -1 [id P] shape=()
    │  └─ Subtensor{::step} [id Q] shape=(?,) 20 (outer_in_mit_mot-0)
    │     ├─ IncSubtensor{:stop} [id R] shape=(5,) 19
    │     │  ├─ Alloc [id S] shape=(5,) 1
    │     │  │  ├─ [0.] [id T] shape=(1,)
    │     │  │  └─ 5 [id J] shape=()
    │     │  ├─ Subtensor{::step} [id U] shape=(?,) 18
    │     │  │  ├─ IncSubtensor{:stop} [id V] shape=(4,) 17
    │     │  │  │  ├─ Alloc [id W] shape=(4,) 0
    │     │  │  │  │  ├─ [0.] [id T] shape=(1,)
    │     │  │  │  │  └─ 4 [id D] shape=()
    │     │  │  │  ├─ Subtensor{::step} [id X] shape=(?,) 16
    │     │  │  │  │  ├─ Scan{grad_of_grad_of_scan_fn, while_loop=False, inplace=all}.1 [id Y] shape=(?,) 15 (outer_out_nit_sot-0)
    │     │  │  │  │  │  ├─ 4 [id D] shape=() (outer_in_nit_sot-0)
    │     │  │  │  │  │  ├─ Subtensor{:stop} [id Z] shape=(?,) 12 (outer_in_seqs-0)
    │     │  │  │  │  │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
    │     │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  └─ 4 [id BA] shape=()
    │     │  │  │  │  │  ├─ Subtensor{start:stop} [id BB] shape=(?,) 11 (outer_in_seqs-1)
    │     │  │  │  │  │  │  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id F] shape=(?,) 8 (outer_out_sit_sot-0)
    │     │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  ├─ 1 [id M] shape=()
    │     │  │  │  │  │  │  └─ 5 [id BC] shape=()
    │     │  │  │  │  │  ├─ Subtensor{start:stop:step} [id BD] shape=(?,) 14 (outer_in_seqs-2)
    │     │  │  │  │  │  │  ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id BE] shape=(?,) 13 (outer_out_mit_mot-0)
    │     │  │  │  │  │  │  │  ├─ 4 [id D] shape=() (n_steps)
    │     │  │  │  │  │  │  │  ├─ Subtensor{start:stop:step} [id E] shape=(?,) 10 (outer_in_seqs-0)
    │     │  │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  │  └─ Subtensor{::step} [id BF] shape=(?,) 9 (outer_in_mit_mot-0)
    │     │  │  │  │  │  │  │     ├─ IncSubtensor{start:} [id BG] shape=(?,) 7
    │     │  │  │  │  │  │  │     │  ├─ Alloc [id S] shape=(5,) 1
    │     │  │  │  │  │  │  │     │  │  └─ ···
    │     │  │  │  │  │  │  │     │  ├─ IncSubtensor{i} [id BH] shape=(?,) 4
    │     │  │  │  │  │  │  │     │  │  ├─ Alloc [id W] shape=(4,) 0
    │     │  │  │  │  │  │  │     │  │  │  └─ ···
    │     │  │  │  │  │  │  │     │  │  ├─ 1.0 [id BI] shape=()
    │     │  │  │  │  │  │  │     │  │  └─ -1 [id P] shape=()
    │     │  │  │  │  │  │  │     │  └─ 1 [id M] shape=()
    │     │  │  │  │  │  │  │     └─ -1 [id P] shape=()
    │     │  │  │  │  │  │  ├─ 3 [id N] shape=()
    │     │  │  │  │  │  │  ├─ -6 [id O] shape=()
    │     │  │  │  │  │  │  └─ -1 [id P] shape=()
    │     │  │  │  │  │  ├─ IncSubtensor{:stop} [id BJ] shape=(5,) 5 (outer_in_mit_mot-0)
    │     │  │  │  │  │  │  ├─ Alloc [id S] shape=(5,) 1
    │     │  │  │  │  │  │  │  └─ ···
    │     │  │  │  │  │  │  ├─ [1.] [id BK] shape=(1,)
    │     │  │  │  │  │  │  └─ 1 [id M] shape=()
    │     │  │  │  │  │  └─ 4 [id D] shape=() (outer_in_nit_sot-0)
    │     │  │  │  │  └─ -1 [id P] shape=()
    │     │  │  │  └─ 4 [id BA] shape=()
    │     │  │  └─ -1 [id P] shape=()
    │     │  └─ -1 [id P] shape=()
    │     └─ -1 [id P] shape=()
    ├─ 4 [id BA] shape=()
    ├─ 3 [id N] shape=()
    └─ -1 [id P] shape=()

Inner graphs:

Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id C]
 ← Add [id BL] shape=() (inner_out_mit_mot-0-0)
    ├─ Mul [id BM] shape=()
    │  ├─ 2.0 [id BN] shape=()
    │  ├─ *1-<Scalar(float64, shape=())> [id BO] shape=() -> [id Q] (inner_in_mit_mot-0-0)
    │  └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id E] (inner_in_seqs-0)
    └─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id Q] (inner_in_mit_mot-0-1)

Scan{scan_fn, while_loop=False, inplace=all} [id F]
 ← Sqr [id BR] shape=() (inner_out_sit_sot-0)
    └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id H] (inner_in_sit_sot-0)

Scan{grad_of_grad_of_scan_fn, while_loop=False, inplace=all} [id Y]
 ← Add [id BS] shape=() (inner_out_mit_mot-0-0)
    ├─ Mul [id BT] shape=()
    │  ├─ 2.0 [id BU] shape=()
    │  ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
    │  └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id Z] (inner_in_seqs-0)
    └─ *4-<Scalar(float64, shape=())> [id BW] shape=() -> [id BJ] (inner_in_mit_mot-0-1)
 ← Add [id BX] shape=() (inner_out_mit_mot-0-1)
    ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
    └─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
 ← Mul [id BY] shape=() (inner_out_nit_sot-0)
    ├─ 2.0 [id BU] shape=()
    ├─ *3-<Scalar(float64, shape=())> [id BV] shape=() -> [id BJ] (inner_in_mit_mot-0-0)
    └─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id BD] (inner_in_seqs-2)

Scan{grad_of_scan_fn, while_loop=False, inplace=all} [id BE]
 ← Add [id BZ] shape=() (inner_out_mit_mot-0-0)
    ├─ Mul [id CA] shape=()
    │  ├─ 2.0 [id CB] shape=()
    │  ├─ *1-<Scalar(float64, shape=())> [id BO] shape=() -> [id BF] (inner_in_mit_mot-0-0)
    │  └─ *0-<Scalar(float64, shape=())> [id BP] shape=() -> [id E] (inner_in_seqs-0)
    └─ *2-<Scalar(float64, shape=())> [id BQ] shape=() -> [id BF] (inner_in_mit_mot-0-1)

ricardoV94 · Oct 15 '25 16:10

TIL it takes 4 Scans to do second order auto-diff :)

ricardoV94 · Oct 15 '25 16:10

> TIL it takes 4 Scans to do second order auto-diff :)

Since the backward of a scan is a scan, the number of scans for the nth derivative should be 2^n?

jessegrabowski · Oct 16 '25 14:10

Yeah. I'm writing an example on scan autodiff. It could be less than 2^n if some Scans can be merged, I guess.

ricardoV94 · Oct 16 '25 14:10
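
A rough way to look at that count, as a sketch: take successive gradients of the same toy scan and count the Scan nodes in the unrewritten graph (FunctionGraph comes from pytensor.graph.fg and the Scan op from pytensor.scan.op; the exact numbers may shrink as rewrites such as scan merging improve).

import pytensor
import pytensor.tensor as pt
from pytensor.graph.fg import FunctionGraph
from pytensor.scan.op import Scan

x0 = pt.scalar("x0")
ys, _ = pytensor.scan(lambda ytm1: ytm1 ** 2, outputs_info=[x0], n_steps=4)
expr = ys[-1]
for order in (1, 2):
    expr = pt.grad(expr, x0)
    # Count Scan nodes before any rewrites run
    fg = FunctionGraph([x0], [expr], clone=True)
    n_scans = sum(isinstance(node.op, Scan) for node in fg.apply_nodes)
    print(f"derivative order {order}: {n_scans} Scan nodes")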