Missed scan rewrites

Description

There are two issues with the code generated by this snippet:

import pytensor
import pytensor.tensor as pt

def update(x):
    return pt.exp(x) - 5

x_init = pt.vector("x_init", shape=(7,))
x_init_tangent = pt.vector("x_init_tangent", shape=(7,))
seq, updates = pytensor.scan(update, outputs_info=[x_init], n_steps=10)
outputs = seq[-1]
output_tangent = pytensor.Rop(outputs, x_init, eval_points=x_init_tangent)

with pytensor.config.change_flags(optimizer_verbose=False):
    func = pytensor.function([x_init, x_init_tangent], [outputs, output_tangent], mode=pytensor.compile.mode.get_mode("FAST_RUN"))

pytensor.dprint(func, print_type=True, print_destroy_map=True)
Subtensor{i} [id A] <Vector(float64, shape=(7,))> 13
 ├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.0 [id B] <Matrix(float64, shape=(?, 7))> 12
 │  ├─ 10 [id C] <Scalar(int8, shape=())>
 │  ├─ SetSubtensor{:stop} [id D] <Matrix(float64, shape=(2, 7))> 11
 │  │  ├─ AllocEmpty{dtype='float64'} [id E] <Matrix(float64, shape=(2, 7))> 10
 │  │  │  ├─ 2 [id F] <Scalar(int64, shape=())>
 │  │  │  └─ 7 [id G] <Scalar(int64, shape=())>
 │  │  ├─ SpecifyShape [id H] <Matrix(float64, shape=(1, 7))> 7
 │  │  │  ├─ Unbroadcast{0} [id I] <Matrix(float64, shape=(?, 7))> 6
 │  │  │  │  └─ ExpandDims{axis=0} [id J] <Matrix(float64, shape=(1, 7))> 5
 │  │  │  │     └─ x_init [id K] <Vector(float64, shape=(7,))>
 │  │  │  ├─ 1 [id L] <Scalar(int8, shape=())>
 │  │  │  └─ 7 [id M] <Scalar(int8, shape=())>
 │  │  └─ 1 [id N] <int64>
 │  ├─ SetSubtensor{:stop} [id O] <Matrix(float64, shape=(1, 7))> 9
 │  │  ├─ AllocEmpty{dtype='float64'} [id P] <Matrix(float64, shape=(1, 7))> 8
 │  │  │  ├─ 1 [id Q] <Scalar(int64, shape=())>
 │  │  │  └─ 7 [id G] <Scalar(int64, shape=())>
 │  │  ├─ SpecifyShape [id H] <Matrix(float64, shape=(1, 7))> 7
 │  │  │  └─ ···
 │  │  └─ 1 [id N] <int64>
 │  └─ SetSubtensor{:stop} [id R] <Matrix(float64, shape=(2, 7))> 4
 │     ├─ AllocEmpty{dtype='float64'} [id S] <Matrix(float64, shape=(2, 7))> 3
 │     │  ├─ 2 [id T] <Scalar(int64, shape=())>
 │     │  └─ 7 [id G] <Scalar(int64, shape=())>
 │     ├─ SpecifyShape [id U] <Matrix(float64, shape=(1, 7))> 2
 │     │  ├─ Unbroadcast{0} [id V] <Matrix(float64, shape=(?, 7))> 1
 │     │  │  └─ ExpandDims{axis=0} [id W] <Matrix(float64, shape=(1, 7))> 0
 │     │  │     └─ x_init_tangent [id X] <Vector(float64, shape=(7,))>
 │     │  ├─ 1 [id L] <Scalar(int8, shape=())>
 │     │  └─ 7 [id M] <Scalar(int8, shape=())>
 │     └─ 1 [id N] <int64>
 └─ 1 [id Y] <uint8>
Subtensor{i} [id Z] <Vector(float64, shape=(7,))> 14
 ├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.2 [id B] <Matrix(float64, shape=(?, 7))> 12
 │  └─ ···
 └─ 1 [id Y] <uint8>

Inner graphs:

Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all} [id B]
 ← Composite{(exp(i0) - 5.0)} [id BA] <Vector(float64, shape=(7,))>
    └─ *0-<Vector(float64, shape=(7,))> [id BB] <Vector(float64, shape=(7,))> -> [id D]
 ← Composite{...}.0 [id BC] <Vector(float64, shape=(7,))>
    ├─ *1-<Vector(float64, shape=(7,))> [id BD] <Vector(float64, shape=(7,))> -> [id O]
    └─ *2-<Vector(float64, shape=(7,))> [id BE] <Vector(float64, shape=(7,))> -> [id R]
 ← Composite{...}.1 [id BC] <Vector(float64, shape=(7,))>
    └─ ···

Composite{(exp(i0) - 5.0)} [id BA]
 ← sub [id BF] <float64> 'o0'
    ├─ exp [id BG] <float64>
    │  └─ i0 [id BH] <float64>
    └─ 5.0 [id BI] <float64>

Composite{...} [id BC]
 ← sub [id BJ] <float64> 'o0'
    ├─ exp [id BK] <float64> 't3'
    │  └─ i0 [id BL] <float64>
    └─ 5.0 [id BM] <float64>
 ← mul [id BN] <float64> 'o1'
    ├─ exp [id BK] <float64> 't3'
    │  └─ ···
    └─ i1 [id BO] <float64>

  • Intermediate arrays have shape (2, 7) instead of (1, 7) (this also happens without the Rop).
  • We compute exp(x) twice in the loop.

cc @ricardoV94

aseyboldt · May 28 '24 16:05

Regarding the (2, 7) instead of (1, 7): my guess is this may be to facilitate inplace rewrites without the need for a deepcopy? For instance, the inplace logic for CompositeOps with multiple outputs is not trivial, because we have to make sure the input is not modified in place while it is still needed to compute other outputs (#138)?
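
Schematically (ignoring how Elemwise actually loops over elements), the hazard with a multi-output Composite destroying its input can be seen in this minimal NumPy sketch, which is not PyTensor code and only mirrors the exp(x) - 5 example above: if the first output is written over the input buffer, the second output is computed from already-clobbered values.

import numpy as np

# Toy stand-in for a two-output Composite: o0 = exp(x) - 5, o1 = exp(x) * y.
def composite_no_inplace(x, y):
    e = np.exp(x)               # exp(x) kept in a temporary
    return e - 5, e * y

def composite_inplace(x, y):
    o0 = np.exp(x, out=x)       # exp(x) written over x's buffer
    o0 -= 5                     # o0 = exp(x) - 5, still aliasing x
    o1 = np.exp(x) * y          # wrong: x now holds exp(x) - 5, not the original x
    return o0, o1

x = np.ones(3)
y = np.full(3, 2.0)
print(composite_no_inplace(x.copy(), y))  # second output is 2 * e
print(composite_inplace(x.copy(), y))     # second output is corrupted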

ricardoV94 · Jun 26 '24 09:06

This comment confirms my suspicion about why the buffer has a +1 length: https://github.com/pymc-devs/pytensor/blob/d7c787555ffc51b64c40779d7c2d20dbea32c96f/pytensor/scan/op.py#L33-L42

It's to avoid overwriting the input too soon. This is the same reason why we only implemented Inplace for Composites with a single output: https://github.com/pymc-devs/pytensor/issues/138

So we can be more aggressive if there is a single output, but perhaps not if there are more, at least not without inspecting the inner graph carefully.
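
For intuition, here is a rough pure-Python sketch of that storage scheme. It is not the actual Scan implementation, just the idea behind the taps + 1 buffer: with two rows, each step writes into a different row than the one it reads, so the previous state is never overwritten while the inner function might still need it.

import numpy as np

n_steps, width = 10, 7
buf = np.empty((2, width))      # taps + 1 rows: read row and write row are distinct
buf[0] = np.zeros(width)        # initial state

for t in range(n_steps):
    prev = buf[t % 2]                      # row holding x_{t-1}
    buf[(t + 1) % 2] = np.exp(prev) - 5    # the new step lands in the other row

x_final = buf[n_steps % 2]

With a single-row buffer, the update would have to be written over the very row it is reading from, which is only safe when every output of the inner function can be computed fully in place on that input.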

ricardoV94 · Nov 12 '24 15:11

This concern shouldn't apply to the JIT backends, though: there's no mechanism to provide the output buffer to the inner function in those backends, so the optimization can be more aggressive there.

ricardoV94 · Mar 07 '25 09:03

The buffer size is tweaked for the JIT backends in #1281.

The other issue, the repeated computations, is trickier. There's no rewrite that identifies when one trace is identical to another; for that to be the case, they must have the same initial state and the same output expression.

Here is a more direct example:

import pytensor
import pytensor.tensor as pt

def update(x, y):
    return pt.exp(x) - 5, pt.exp(y) - 5

x_init = pt.vector("x_init", shape=(7,))
[sx, sy], updates = pytensor.scan(update, outputs_info=[x_init, x_init], n_steps=10)

with pytensor.config.change_flags(optimizer_verbose=False):
    func = pytensor.function([x_init], [sx, sy])

func.dprint(print_shape=True, print_op_info=True)

# Subtensor{start:} [id A] shape=(?, 7) 8
#  ├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] shape=(?, 7) 6 (outer_out_sit_sot-0)
#  │  ├─ 10 [id C] shape=() (n_steps)
#  │  ├─ SetSubtensor{:stop} [id D] shape=(11, 7) 4 (outer_in_sit_sot-0)
#  │  │  ├─ AllocEmpty{dtype='float64'} [id E] shape=(11, 7) 1
#  │  │  │  ├─ 11 [id F] shape=()
#  │  │  │  └─ 7 [id G] shape=()
#  │  │  ├─ SpecifyShape [id H] shape=(1, 7) 3
#  │  │  │  ├─ Unbroadcast{0} [id I] shape=(?, 7) 2
#  │  │  │  │  └─ ExpandDims{axis=0} [id J] shape=(1, 7) 0
#  │  │  │  │     └─ x_init [id K] shape=(7,)
#  │  │  │  ├─ 1 [id L] shape=()
#  │  │  │  └─ 7 [id M] shape=()
#  │  │  └─ 1 [id N] shape=()
#  │  └─ DeepCopyOp [id O] shape=(11, 7) 5 (outer_in_sit_sot-1)
#  │     └─ SetSubtensor{:stop} [id D] shape=(11, 7) 4
#  │        └─ ···
#  └─ 1 [id N] shape=()
# Subtensor{start:} [id P] shape=(?, 7) 7
#  ├─ Scan{scan_fn, while_loop=False, inplace=all}.1 [id B] shape=(?, 7) 6 (outer_out_sit_sot-1)
#  │  └─ ···
#  └─ 1 [id N] shape=()
# Inner graphs:
# Scan{scan_fn, while_loop=False, inplace=all} [id B]
#  ← Composite{(-5.0 + exp(i0))} [id Q] shape=(7,) (inner_out_sit_sot-0)
#     └─ *0-<Vector(float64, shape=(7,))> [id R] shape=(7,) -> [id D] (inner_in_sit_sot-0)
#  ← Composite{(-5.0 + exp(i0))} [id S] shape=(7,) (inner_out_sit_sot-1)
#     └─ *1-<Vector(float64, shape=(7,))> [id T] shape=(7,) -> [id O] (inner_in_sit_sot-1)

PyTensor knows the initial state is the same, so much so that it has to introduce a DeepCopy when it adds the inplace optimization to Scan itself. But it never tries to check whether the inner graphs for the repeated inputs are identical.

It wouldn't be too hard to implement such a rewrite, but is it common enough for us to bother?
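
If someone wanted to try, a very rough sketch of the detection step could look like the following. This is not a registered rewrite: how the inner recurrent inputs/outputs and the outer initial states are extracted from a real Scan node is glossed over, and equal_computations from pytensor.graph.basic is used to compare the update expressions while mapping one recurrent input onto the other.

from pytensor.graph.basic import equal_computations

def duplicate_sit_sot_traces(inner_ins, inner_outs, outer_inits):
    """Report pairs (i, j) of recurrent traces that compute the same thing."""
    dups = []
    for i in range(len(inner_outs)):
        for j in range(i + 1, len(inner_outs)):
            # Same initial state: the outer graph feeds the same variable twice.
            same_init = outer_inits[i] is outer_inits[j]
            # Same update expression, treating recurrent input i as recurrent input j.
            same_update = equal_computations(
                [inner_outs[i]], [inner_outs[j]],
                in_xs=[inner_ins[i]], in_ys=[inner_ins[j]],
            )
            if same_init and same_update:
                dups.append((i, j))  # output j could be replaced by output i
    return dups

A rewrite built on top of this would then redirect clients of the Scan output for trace j to the output for trace i and prune the duplicated recurrence.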

ricardoV94 · Mar 10 '25 15:03