pytensor
Implement several subtensor lift rewrites
These rewrites reduce computation on batch dimensions by lifting simple indexing operations closer to the inputs, so that entries which are indexed away are never computed in the first place.
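The idea can be illustrated outside PyTensor with a minimal NumPy sketch. It assumes only the simplest case, a `sum` over one axis followed by an integer index on the leading output axis (the same pattern as `x.sum(axis=1)[0]` in the PyTensor example). The helpers `sum_then_index` and `index_then_sum` are hypothetical names for illustration; this is not PyTensor's rewrite code. The one non-obvious step is axis bookkeeping: the indexed output axis must be mapped back to the matching input axis, and the reduction axis shifts once an earlier axis is indexed away.

```python
import numpy as np


def sum_then_index(x, reduce_axis, i):
    """Reference order: reduce first, then index the leading output axis."""
    return x.sum(axis=reduce_axis)[i]


def index_then_sum(x, reduce_axis, i):
    """Lifted order (hypothetical sketch): index the input first, then reduce.

    Assumes x.ndim >= 2. Only the slice selected by ``i`` is ever summed.
    """
    reduce_axis = reduce_axis % x.ndim  # normalize negative axes
    # The leading output axis is input axis 0, unless axis 0 is the one reduced
    input_axis = 1 if reduce_axis == 0 else 0
    # An integer index drops input_axis, like x[i] or x[:, i]
    x_indexed = np.take(x, i, axis=input_axis)
    # Dropping an axis that precedes reduce_axis shifts reduce_axis left by one
    if reduce_axis > input_axis:
        reduce_axis -= 1
    return x_indexed.sum(axis=reduce_axis)


rng = np.random.default_rng(0)
x = rng.normal(size=(512, 512))
# Same pattern as x.sum(axis=1)[0]: both orders give the same value
assert np.allclose(sum_then_index(x, 1, 0), index_then_sum(x, 1, 0))
```

The lifted version touches one row of 512 values instead of all 262,144, which is the kind of saving the rewrite aims for.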
A concrete PyTensor example is indexing the result of a reduction:
```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

mode = get_default_mode()

x = pt.matrix("x", shape=(512, 512))
x_test = np.random.normal(size=x.type.shape)

# Sum each row (reduce axis 1), then keep only the first row's total
x_sum = x.sum(axis=1)
out = x_sum[0]

# Compile with the new rewrite disabled, for comparison
fn_before = pytensor.function([x], out, mode=mode.excluding("local_subtensor_of_reduce"))
fn_before.dprint(print_type=True)
%timeit fn_before(x_test)  # IPython magic: run in IPython/Jupyter
# Subtensor{i} [id A] <Scalar(float64, shape=())> 1
#  ├─ Sum{axis=1} [id B] <Vector(float64, shape=(512,))> 0
#  │  └─ x [id C] <Matrix(float64, shape=(512, 512))>
#  └─ 0 [id D] <uint8>
# 762 μs ± 7.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

print()

# Compile with the default rewrites: the index is lifted above the Sum
fn_after = pytensor.function([x], out, mode=mode)
fn_after.dprint(print_type=True)
%timeit fn_after(x_test)
# Sum{axes=None} [id A] <Scalar(float64, shape=())> 1
#  └─ Subtensor{i} [id B] <Vector(float64, shape=(512,))> 0
#     ├─ x [id C] <Matrix(float64, shape=(512, 512))>
#     └─ 0 [id D] <uint8>
# 5.26 μs ± 86 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
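A back-of-the-envelope count of the values each compiled graph has to add up helps explain the speedup. This is an idealized operation count for the (512, 512) input, not a model of PyTensor's runtime; it ignores memory traffic and per-call overhead.

```python
n_rows, n_cols = 512, 512  # x has shape (512, 512)

# Before: Sum{axis=1} adds up every row, then Subtensor{i} keeps one result
elements_summed_before = n_rows * n_cols
# After: Subtensor{i} selects row 0 first, so only that row is summed
elements_summed_after = n_cols

print(elements_summed_before, elements_summed_after,
      elements_summed_before // elements_summed_after)
# 262144 512 512
```

The measured ratio in the timings above is about 145× (762 μs / 5.26 μs), well short of 512×, so the count is best read as the arithmetic saved rather than a prediction of wall-clock speedup.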
📚 Documentation preview 📚: https://pytensor--1158.org.readthedocs.build/en/1158/
Codecov Report
:x: Patch coverage is 90.97561% with 37 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 82.07%. Comparing base (5335a68) to head (4b545cf).
:warning: Report is 176 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/rewriting/subtensor_lift.py | 90.77% | 17 Missing and 20 partials :warning: |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #1158      +/-   ##
==========================================
+ Coverage   82.02%   82.07%   +0.05%
==========================================
  Files         207      208       +1
  Lines       49301    49517     +216
  Branches     8747     8785      +38
==========================================
+ Hits        40440    40642     +202
- Misses       6695     6702       +7
- Partials     2166     2173       +7
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/tensor/rewriting/subtensor.py | `90.08% <100.00%> (+0.49%)` | :arrow_up: |
| pytensor/tensor/rewriting/subtensor_lift.py | `90.77% <90.77%> (ø)` | |