
Implement several subtensor lift rewrites

Open ricardoV94 opened this issue 11 months ago • 1 comment

These rewrites reduce computation on batch dimensions by lifting simple indexing operations closer to the inputs. The operations in between then only process the entries that are actually selected.
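The identity behind such a lift can be checked with plain NumPy. This is an illustration, not PyTensor code. It assumes an elementwise op (`np.exp`) as the operation being lifted past. The example in this description only covers the reduction case, so treat the choice of op as illustrative:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(512, 512))

# Original order: apply the elementwise op to the whole matrix, then take row 0.
before = np.exp(x)[0]  # exp is evaluated on 512 * 512 = 262,144 entries
# Lifted order: take row 0 first, then apply the op to that row only.
after = np.exp(x[0])   # exp is evaluated on 512 entries

assert np.allclose(before, after)
```

Both orders produce the same row, but the lifted one does 512 times less elementwise work.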

An obvious example is:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

mode = get_default_mode()

x = pt.matrix("x", shape=(512, 512))
x_test = np.random.normal(size=x.type.shape)

# Index the result of a reduction over the other axis
x_sum = x.sum(axis=1)
out = x_sum[0]

# Before: exclude the new rewrite, so the full Sum is computed and then indexed.
# (%timeit is an IPython magic; run this in IPython or Jupyter.)
fn_before = pytensor.function([x], out, mode=mode.excluding("local_subtensor_of_reduce"))
fn_before.dprint(print_type=True)
%timeit fn_before(x_test)
# Subtensor{i} [id A] <Scalar(float64, shape=())> 1
#  ├─ Sum{axis=1} [id B] <Vector(float64, shape=(512,))> 0
#  │  └─ x [id C] <Matrix(float64, shape=(512, 512))>
#  └─ 0 [id D] <uint8>
# 762 μs ± 7.55 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

print()
# After: the default mode lifts the index above the Sum, so only row 0 is summed.
fn_after = pytensor.function([x], out, mode=mode)
fn_after.dprint(print_type=True)
%timeit fn_after(x_test)
# Sum{axes=None} [id A] <Scalar(float64, shape=())> 1
#  └─ Subtensor{i} [id B] <Vector(float64, shape=(512,))> 0
#     ├─ x [id C] <Matrix(float64, shape=(512, 512))>
#     └─ 0 [id D] <uint8>
# 5.26 μs ± 86 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```
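The reduction case needs some extra bookkeeping. An integer index on an axis of the *output* must be mapped back to the matching axis of the *input*. Once that axis is indexed away, the reduced axes must be renumbered.

Below is a NumPy-only sketch of that bookkeeping. `lifted_sum_index` is a hypothetical helper written for illustration, not part of PyTensor. It assumes a single integer index on one output axis and does not handle slices or advanced indexing:

```python
import numpy as np


def lifted_sum_index(x, reduce_axes, out_axis, idx):
    """Hypothetical helper (not PyTensor API): equivalent to
    np.take(x.sum(axis=reduce_axes), idx, axis=out_axis), but indexes x first."""
    reduce_axes = sorted(a % x.ndim for a in reduce_axes)
    # Input axes that survive the reduction; output axis k comes from kept_axes[k].
    kept_axes = [a for a in range(x.ndim) if a not in reduce_axes]
    in_axis = kept_axes[out_axis]
    # Integer indexing drops in_axis, so reduced axes after it shift down by one.
    new_reduce_axes = tuple(a - 1 if a > in_axis else a for a in reduce_axes)
    return np.take(x, idx, axis=in_axis).sum(axis=new_reduce_axes)


rng = np.random.default_rng(0)
x = rng.normal(size=(512, 512))

# The case from the example above: x.sum(axis=1)[0] == x[0].sum()
before = x.sum(axis=1)[0]
after = lifted_sum_index(x, reduce_axes=[1], out_axis=0, idx=0)
assert np.isclose(before, after)

# A 3D case where the reduced axis sits before the indexed one
y = rng.normal(size=(4, 5, 6))
assert np.allclose(
    np.take(y.sum(axis=(0,)), 2, axis=1),  # reduce axis 0, then index output axis 1
    lifted_sum_index(y, reduce_axes=[0], out_axis=1, idx=2),
)
```

For the example above, the lifted form is `x[0].sum()`, which sums 512 entries instead of 262,144. The timings show roughly a 145× speedup rather than 512×, likely because fixed per-call overhead matters at this size.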

📚 Documentation preview 📚: https://pytensor--1158.org.readthedocs.build/en/1158/

ricardoV94 · Jan 20 '25 17:01

Codecov Report

❌ Patch coverage is 90.97561% with 37 lines in your changes missing coverage. Please review.

✅ Project coverage is 82.07%. Comparing base (5335a68) to head (4b545cf).

⚠️ Report is 176 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/rewriting/subtensor_lift.py | 90.77% | 17 Missing and 20 partials ⚠️ |

```diff
@@            Coverage Diff             @@
##             main    #1158      +/-   ##
==========================================
+ Coverage   82.02%   82.07%   +0.05%     
==========================================
  Files         207      208       +1     
  Lines       49301    49517     +216     
  Branches     8747     8785      +38     
==========================================
+ Hits        40440    40642     +202     
- Misses       6695     6702       +7     
- Partials     2166     2173       +7     
```
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/tensor/rewriting/subtensor.py | 90.08% <100.00%> (+0.49%) ⬆️ |
| pytensor/tensor/rewriting/subtensor_lift.py | 90.77% <90.77%> (ø) |

codecov[bot] · Mar 11 '25 14:03