pytensor icon indicating copy to clipboard operation
pytensor copied to clipboard

Recognize `dot` from naive sum of broadcasted muls

Open ricardoV94 opened this issue 7 months ago • 0 comments

Description

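Brought up in https://github.com/pymc-devs/pytensor/pull/858

Summing a broadcasted elementwise multiplication over the shared axis is mathematically a matrix product, but the compiled graph keeps it as `Sum(Mul(...))` instead of rewriting it to a `dot`. In the example below this is roughly 80x slower than the equivalent `a @ b`.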

import pytensor
import pytensor.tensor as pt

a = pt.matrix("a", shape=(200, 300))
b = pt.matrix("b", shape=(300, 400))
dot = (a[:, :, None] * b).sum(1)

fn = pytensor.function([a, b], dot)
pytensor.dprint(fn, print_type=True)
# Sum{axis=1} [id A] <Matrix(float64, shape=(200, 400))> 3
#  └─ Mul [id B] <Tensor3(float64, shape=(200, 300, 400))> 2
#     ├─ ExpandDims{axis=2} [id C] <Tensor3(float64, shape=(200, 300, 1))> 1
#     │  └─ a [id D] <Matrix(float64, shape=(200, 300))>
#     └─ ExpandDims{axis=0} [id E] <Tensor3(float64, shape=(1, 300, 400))> 0
#        └─ b [id F] <Matrix(float64, shape=(300, 400))>

fn_dot = pytensor.function([a, b], a @ b)
print(); pytensor.dprint(fn_dot, print_type=True)
# Dot22 [id A] <Matrix(float64, shape=(200, 400))> 0
#  ├─ a [id B] <Matrix(float64, shape=(200, 300))>
#  └─ b [id C] <Matrix(float64, shape=(300, 400))>

a_test = np.random.normal(size=a.type.shape)
b_test = np.random.normal(size=b.type.shape)
np.testing.assert_allclose(fn(a_test, b_test), fn_dot(a_test, b_test))

%timeit fn(a_test, b_test)  # 70.9 ms ± 1.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit fn_dot(a_test, b_test)  # 861 µs ± 148 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

ricardoV94 avatar Jun 27 '24 18:06 ricardoV94