pytensor icon indicating copy to clipboard operation
pytensor copied to clipboard

Simplify subtensor shape inference

Open ricardoV94 opened this issue 9 months ago • 0 comments

This PR simplifies the expressions returned by Subtensor.infer_shape. Some examples:

This PR also implements some basic minmax algebra rewrites to handle constraints implied by slice inputs. The following shape graph is now as succinct as it can be without knowing the shape of x (in which case it would be constant ofc)

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = x[1:-1][1:-1][1:-1].shape[0]
fn = pytensor.function([x], y)
fn.dprint()
Composite{maximum(0, (-6 + i0))} [id A] 1
 └─ Shape_i{0} [id B] 0
    └─ x [id C]

Inner graphs:

Composite{maximum(0, (-6 + i0))} [id A]
 ← maximum [id D] 'o0'
    ├─ 0 [id E]
    └─ add [id F]
       ├─ -6 [id G]
       └─ i0 [id H]

If you dare check what was the output before:

Composite{...} [id A] 1
 └─ Shape_i{0} [id B] 0
    └─ x [id C]
Inner graphs:
Composite{...} [id A]
 ← sub [id D] 'o0'
    ├─ Switch [id E] 't31'
    │  ├─ LT [id F]
    │  │  ├─ Switch [id G] 't17'
    │  │  │  ├─ GE [id H]
    │  │  │  │  ├─ Switch [id I] 't2'
    │  │  │  │  │  ├─ LT [id J]
    │  │  │  │  │  │  ├─ add [id K] 't37'
    │  │  │  │  │  │  │  ├─ -1 [id L]
    │  │  │  │  │  │  │  └─ sub [id M] 't34'
    │  │  │  │  │  │  │     ├─ Switch [id N] 't33'
    │  │  │  │  │  │  │     │  ├─ LT [id O]
    │  │  │  │  │  │  │     │  │  ├─ Switch [id P] 't19'
    │  │  │  │  │  │  │     │  │  │  ├─ GE [id Q]
    │  │  │  │  │  │  │     │  │  │  │  ├─ Switch [id R] 't5'
    │  │  │  │  │  │  │     │  │  │  │  │  ├─ LT [id S]
    │  │  │  │  │  │  │     │  │  │  │  │  │  ├─ add [id T] 't39'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │  ├─ -1 [id L]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │  └─ sub [id U] 't36'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     ├─ Switch [id V] 't35'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  ├─ LT [id W]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  ├─ Switch [id X] 't21'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  ├─ GE [id Y]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  ├─ Switch [id Z] 't12'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  ├─ LT [id BA]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  ├─ add [id BB] 't4'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  │  ├─ -1 [id L]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  │  └─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  ├─ -1 [id BE]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │  └─ add [id BB] 't4'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  │  └─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  ├─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │  └─ Switch [id Z] 't12'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  ├─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │  └─ Switch [id X] 't21'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │     └─ Switch [id BF]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        ├─ LT [id BG]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  ├─ Switch [id BH] 't15'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  ├─ LT [id BI]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  ├─ Switch [id BJ] 't45'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  ├─ GE [id BK]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  │  ├─ 1 [id BL]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  │  └─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  ├─ i0 [id BC]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  │  └─ 1 [id BL]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  ├─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │  └─ Switch [id BJ] 't45'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  └─ Switch [id V] 't35'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        ├─ Switch [id BH] 't15'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        │  └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  │        └─ Switch [id V] 't35'
    │  │  │  │  │  │  │     │  │  │  │  │  │  │           └─ ···
    │  │  │  │  │  │  │     │  │  │  │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  │  │  │  │  ├─ -1 [id BE]
    │  │  │  │  │  │  │     │  │  │  │  │  └─ add [id T] 't39'
    │  │  │  │  │  │  │     │  │  │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  │  └─ sub [id U] 't36'
    │  │  │  │  │  │  │     │  │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  │  ├─ sub [id U] 't36'
    │  │  │  │  │  │  │     │  │  │  │  └─ ···
    │  │  │  │  │  │  │     │  │  │  └─ Switch [id R] 't5'
    │  │  │  │  │  │  │     │  │  │     └─ ···
    │  │  │  │  │  │  │     │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │     │  ├─ 0 [id BD]
    │  │  │  │  │  │  │     │  └─ Switch [id P] 't19'
    │  │  │  │  │  │  │     │     └─ ···
    │  │  │  │  │  │  │     └─ Switch [id BM]
    │  │  │  │  │  │  │        ├─ LT [id BN]
    │  │  │  │  │  │  │        │  ├─ Switch [id BO] 't13'
    │  │  │  │  │  │  │        │  │  ├─ LT [id BP]
    │  │  │  │  │  │  │        │  │  │  ├─ Switch [id BQ] 't43'
    │  │  │  │  │  │  │        │  │  │  │  ├─ GE [id BR]
    │  │  │  │  │  │  │        │  │  │  │  │  ├─ 1 [id BL]
    │  │  │  │  │  │  │        │  │  │  │  │  └─ sub [id U] 't36'
    │  │  │  │  │  │  │        │  │  │  │  │     └─ ···
    │  │  │  │  │  │  │        │  │  │  │  ├─ sub [id U] 't36'
    │  │  │  │  │  │  │        │  │  │  │  │  └─ ···
    │  │  │  │  │  │  │        │  │  │  │  └─ 1 [id BL]
    │  │  │  │  │  │  │        │  │  │  └─ 0 [id BD]
    │  │  │  │  │  │  │        │  │  ├─ 0 [id BD]
    │  │  │  │  │  │  │        │  │  └─ Switch [id BQ] 't43'
    │  │  │  │  │  │  │        │  │     └─ ···
    │  │  │  │  │  │  │        │  └─ Switch [id N] 't33'
    │  │  │  │  │  │  │        │     └─ ···
    │  │  │  │  │  │  │        ├─ Switch [id BO] 't13'
    │  │  │  │  │  │  │        │  └─ ···
    │  │  │  │  │  │  │        └─ Switch [id N] 't33'
    │  │  │  │  │  │  │           └─ ···
    │  │  │  │  │  │  └─ 0 [id BD]
    │  │  │  │  │  ├─ -1 [id BE]
    │  │  │  │  │  └─ add [id K] 't37'
    │  │  │  │  │     └─ ···
    │  │  │  │  └─ sub [id M] 't34'
    │  │  │  │     └─ ···
    │  │  │  ├─ sub [id M] 't34'
    │  │  │  │  └─ ···
    │  │  │  └─ Switch [id I] 't2'
    │  │  │     └─ ···
    │  │  └─ 0 [id BD]
    │  ├─ 0 [id BD]
    │  └─ Switch [id G] 't17'
    │     └─ ···
    └─ Switch [id BS]
       ├─ LT [id BT]
       │  ├─ Switch [id BU] 't9'
       │  │  ├─ LT [id BV]
       │  │  │  ├─ Switch [id BW] 't42'
       │  │  │  │  ├─ GE [id BX]
       │  │  │  │  │  ├─ 1 [id BL]
       │  │  │  │  │  └─ sub [id M] 't34'
       │  │  │  │  │     └─ ···
       │  │  │  │  ├─ sub [id M] 't34'
       │  │  │  │  │  └─ ···
       │  │  │  │  └─ 1 [id BL]
       │  │  │  └─ 0 [id BD]
       │  │  ├─ 0 [id BD]
       │  │  └─ Switch [id BW] 't42'
       │  │     └─ ···
       │  └─ Switch [id E] 't31'
       │     └─ ···
       ├─ Switch [id BU] 't9'
       │  └─ ···
       └─ Switch [id E] 't31'
          └─ ···

This is rather important as slices are the most common source of dynamic shapes. We want to make sure these are clever.

This starts working towards fixing #112, it already simplifies standard slice shapes well enough by avoiding the canonical_form_slice monster. The next step is for the rewrite to also avoid it, although the number of parametrized combinations between two adjacent slices is no small feat. We may not need to merge all subtensors, after all a slice is a pretty cheap operation.

The most important thing is to have a good inference on the shape of multiple slices, as the scan save memory rewrite uses that to decide how many steps (cute but not critical) and how many entries to store in the buffer (rather important).

The new rewrites themselves also suggest a way we can start doing more clever type inference stuff. The basics for knowing upper/lower bounding are here. It probably makes sense to offer this as a feature, but for now they are just used two of the new rewrites.

Related Issue

  • [x] Related to #112

Checklist

Type of change

  • [ ] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pytensor--1299.org.readthedocs.build/en/1299/

ricardoV94 avatar Mar 17 '25 18:03 ricardoV94