Missed scan rewrites
Description
There are two issues with the code generated by this snippet:
import pytensor
import pytensor.tensor as pt

def update(x):
    return pt.exp(x) - 5

x_init = pt.vector("x_init", shape=(7,))
x_init_tangent = pt.vector("x_init_tangent", shape=(7,))

seq, updates = pytensor.scan(update, outputs_info=[x_init], n_steps=10)
outputs = seq[-1]
output_tangent = pytensor.Rop(outputs, x_init, eval_points=x_init_tangent)

with pytensor.config.change_flags(optimizer_verbose=False):
    func = pytensor.function([x_init, x_init_tangent], [outputs, output_tangent], mode=pytensor.compile.mode.get_mode("FAST_RUN"))

pytensor.dprint(func, print_type=True, print_destroy_map=True)
Subtensor{i} [id A] <Vector(float64, shape=(7,))> 13
├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.0 [id B] <Matrix(float64, shape=(?, 7))> 12
│ ├─ 10 [id C] <Scalar(int8, shape=())>
│ ├─ SetSubtensor{:stop} [id D] <Matrix(float64, shape=(2, 7))> 11
│ │ ├─ AllocEmpty{dtype='float64'} [id E] <Matrix(float64, shape=(2, 7))> 10
│ │ │ ├─ 2 [id F] <Scalar(int64, shape=())>
│ │ │ └─ 7 [id G] <Scalar(int64, shape=())>
│ │ ├─ SpecifyShape [id H] <Matrix(float64, shape=(1, 7))> 7
│ │ │ ├─ Unbroadcast{0} [id I] <Matrix(float64, shape=(?, 7))> 6
│ │ │ │ └─ ExpandDims{axis=0} [id J] <Matrix(float64, shape=(1, 7))> 5
│ │ │ │ └─ x_init [id K] <Vector(float64, shape=(7,))>
│ │ │ ├─ 1 [id L] <Scalar(int8, shape=())>
│ │ │ └─ 7 [id M] <Scalar(int8, shape=())>
│ │ └─ 1 [id N] <int64>
│ ├─ SetSubtensor{:stop} [id O] <Matrix(float64, shape=(1, 7))> 9
│ │ ├─ AllocEmpty{dtype='float64'} [id P] <Matrix(float64, shape=(1, 7))> 8
│ │ │ ├─ 1 [id Q] <Scalar(int64, shape=())>
│ │ │ └─ 7 [id G] <Scalar(int64, shape=())>
│ │ ├─ SpecifyShape [id H] <Matrix(float64, shape=(1, 7))> 7
│ │ │ └─ ···
│ │ └─ 1 [id N] <int64>
│ └─ SetSubtensor{:stop} [id R] <Matrix(float64, shape=(2, 7))> 4
│ ├─ AllocEmpty{dtype='float64'} [id S] <Matrix(float64, shape=(2, 7))> 3
│ │ ├─ 2 [id T] <Scalar(int64, shape=())>
│ │ └─ 7 [id G] <Scalar(int64, shape=())>
│ ├─ SpecifyShape [id U] <Matrix(float64, shape=(1, 7))> 2
│ │ ├─ Unbroadcast{0} [id V] <Matrix(float64, shape=(?, 7))> 1
│ │ │ └─ ExpandDims{axis=0} [id W] <Matrix(float64, shape=(1, 7))> 0
│ │ │ └─ x_init_tangent [id X] <Vector(float64, shape=(7,))>
│ │ ├─ 1 [id L] <Scalar(int8, shape=())>
│ │ └─ 7 [id M] <Scalar(int8, shape=())>
│ └─ 1 [id N] <int64>
└─ 1 [id Y] <uint8>
Subtensor{i} [id Z] <Vector(float64, shape=(7,))> 14
├─ Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all}.2 [id B] <Matrix(float64, shape=(?, 7))> 12
│ └─ ···
└─ 1 [id Y] <uint8>
Inner graphs:
Scan{scan_fn&rop_of_scan_fn, while_loop=False, inplace=all} [id B]
← Composite{(exp(i0) - 5.0)} [id BA] <Vector(float64, shape=(7,))>
└─ *0-<Vector(float64, shape=(7,))> [id BB] <Vector(float64, shape=(7,))> -> [id D]
← Composite{...}.0 [id BC] <Vector(float64, shape=(7,))>
├─ *1-<Vector(float64, shape=(7,))> [id BD] <Vector(float64, shape=(7,))> -> [id O]
└─ *2-<Vector(float64, shape=(7,))> [id BE] <Vector(float64, shape=(7,))> -> [id R]
← Composite{...}.1 [id BC] <Vector(float64, shape=(7,))>
└─ ···
Composite{(exp(i0) - 5.0)} [id BA]
← sub [id BF] <float64> 'o0'
├─ exp [id BG] <float64>
│ └─ i0 [id BH] <float64>
└─ 5.0 [id BI] <float64>
Composite{...} [id BC]
← sub [id BJ] <float64> 'o0'
├─ exp [id BK] <float64> 't3'
│ └─ i0 [id BL] <float64>
└─ 5.0 [id BM] <float64>
← mul [id BN] <float64> 'o1'
├─ exp [id BK] <float64> 't3'
│ └─ ···
└─ i1 [id BO] <float64>
- Intermediate arrays have shape (2, 7) instead of (1, 7) (this also happens without the Rop)
- We compute exp(x) - 5 twice in the loop, once in the scan_fn composite and again inside the rop_of_scan_fn composite (see the sketch below)
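For reference, the forward-mode recursion behind the Rop only needs one exponential per step: with the state update x_next = exp(x) - 5, the tangent update is v_next = exp(x) * v, so a single exp(x) could serve both the state trace and the tangent trace. A minimal NumPy sketch of what a fused inner step could compute (illustrative only, not what PyTensor currently generates):

import numpy as np

def fused_step(x, v):
    # state update:   x_next = exp(x) - 5
    # tangent update: v_next = exp(x) * v   (the Jacobian of exp(x) - 5 is diag(exp(x)))
    e = np.exp(x)  # computed once, shared by both outputs
    return e - 5, e * v

x = np.zeros(7)
v = np.ones(7)
for _ in range(10):
    x, v = fused_step(x, v)

In the compiled graph above, the state trace is instead duplicated (buffers [id D] and [id O]) and exp(state) - 5 is evaluated once in each composite.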
cc @ricardoV94
Regarding the (2, 7) instead of (1, 7): my guess is this may be to facilitate inplace rewrites without the need for a deepcopy? For instance, the inplace logic for Composite Ops with multiple outputs is not trivial, because we have to make sure the input is not modified in place while it is still needed to compute other outputs (#138)?
This comment confirms my suspicion about why the buffer has a +1 length: https://github.com/pymc-devs/pytensor/blob/d7c787555ffc51b64c40779d7c2d20dbea32c96f/pytensor/scan/op.py#L33-L42
It's to avoid overwriting the input too soon. This is the same reason why we only implemented Inplace for Composites with a single output: https://github.com/pymc-devs/pytensor/issues/138
So we can be more aggressive if there is a single output, but perhaps not if there are more, at least not without inspecting the inner graph carefully.
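To make the hazard concrete, here is a rough NumPy picture of the rotating buffer (my illustration, not the actual Scan implementation): with a single row, writing the new state in place would destroy the previous state before any second output that still needs it (e.g. the Rop trace) could be computed, so Scan keeps one extra row.

import numpy as np

x = np.zeros(7)

# Two rows instead of one: the previous state stays intact while the new
# state is written, so anything else that still reads `prev` sees valid data.
buf = np.empty((2, 7))
buf[0] = x
for t in range(10):
    prev = buf[t % 2]
    new = buf[(t + 1) % 2]
    new[:] = np.exp(prev) - 5  # `prev` is left untouched
final_state = buf[10 % 2]

With a single elementwise output this is overly cautious, since each input element can be read before it is overwritten in place, which is why the single-output case can be handled more aggressively.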
This concern shouldn't apply to the JIT backends, though: there is no mechanism to provide the output buffer to the inner function in those backends, so the optimization can be more aggressive there.
The buffer size is tweaked for the JIT backends in #1281.
The other issue, the repeated computation, is trickier. There is no rewrite that identifies when one trace is identical to another; for that to be the case, they must have the same initial state and the same output expression.
Here is a more direct example:
import pytensor
import pytensor.tensor as pt

def update(x, y):
    return pt.exp(x) - 5, pt.exp(y) - 5

x_init = pt.vector("x_init", shape=(7,))
[sx, sy], updates = pytensor.scan(update, outputs_info=[x_init, x_init], n_steps=10)

with pytensor.config.change_flags(optimizer_verbose=False):
    func = pytensor.function([x_init], [sx, sy])

func.dprint(print_shape=True, print_op_info=True)
# Subtensor{start:} [id A] shape=(?, 7) 8
# ├─ Scan{scan_fn, while_loop=False, inplace=all}.0 [id B] shape=(?, 7) 6 (outer_out_sit_sot-0)
# │ ├─ 10 [id C] shape=() (n_steps)
# │ ├─ SetSubtensor{:stop} [id D] shape=(11, 7) 4 (outer_in_sit_sot-0)
# │ │ ├─ AllocEmpty{dtype='float64'} [id E] shape=(11, 7) 1
# │ │ │ ├─ 11 [id F] shape=()
# │ │ │ └─ 7 [id G] shape=()
# │ │ ├─ SpecifyShape [id H] shape=(1, 7) 3
# │ │ │ ├─ Unbroadcast{0} [id I] shape=(?, 7) 2
# │ │ │ │ └─ ExpandDims{axis=0} [id J] shape=(1, 7) 0
# │ │ │ │ └─ x_init [id K] shape=(7,)
# │ │ │ ├─ 1 [id L] shape=()
# │ │ │ └─ 7 [id M] shape=()
# │ │ └─ 1 [id N] shape=()
# │ └─ DeepCopyOp [id O] shape=(11, 7) 5 (outer_in_sit_sot-1)
# │ └─ SetSubtensor{:stop} [id D] shape=(11, 7) 4
# │ └─ ···
# └─ 1 [id N] shape=()
# Subtensor{start:} [id P] shape=(?, 7) 7
# ├─ Scan{scan_fn, while_loop=False, inplace=all}.1 [id B] shape=(?, 7) 6 (outer_out_sit_sot-1)
# │ └─ ···
# └─ 1 [id N] shape=()
# Inner graphs:
# Scan{scan_fn, while_loop=False, inplace=all} [id B]
# ← Composite{(-5.0 + exp(i0))} [id Q] shape=(7,) (inner_out_sit_sot-0)
# └─ *0-<Vector(float64, shape=(7,))> [id R] shape=(7,) -> [id D] (inner_in_sit_sot-0)
# ← Composite{(-5.0 + exp(i0))} [id S] shape=(7,) (inner_out_sit_sot-1)
# └─ *1-<Vector(float64, shape=(7,))> [id T] shape=(7,) -> [id O] (inner_in_sit_sot-1)
PyTensor knows the initial state is the same, so much so that it has to introduce a DeepCopy when it applies the inplace optimization to the Scan itself. But it never tries to check whether the inner graphs for the repeated inputs are identical.
It wouldn't be too hard to implement such a rewrite, but is it common enough for us to bother?
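For what it's worth, the core of such a rewrite would be an equivalence check: given that the initial states are the same variable, verify that the inner update expressions are equal computations once corresponding inner inputs are identified with each other. A minimal sketch of that check at the user level, using equal_computations from pytensor.graph.basic (turning this into an actual Scan rewrite would additionally need the sit-sot bookkeeping to replace one trace with the other):

import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations

x = pt.vector("x", shape=(7,))
y = pt.vector("y", shape=(7,))

# The two inner updates from the example above, written on separate inputs.
out_x = pt.exp(x) - 5
out_y = pt.exp(y) - 5

# Should print True: the expressions are the same computation once x and y are
# identified, which (together with equal initial states) is what would justify
# merging the two sit-sot traces into one.
print(equal_computations([out_x], [out_y], in_xs=[x], in_ys=[y]))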