pytensor
Recognize `dot` from naive sum of broadcasted muls
Description
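This was brought up in https://github.com/pymc-devs/pytensor/pull/858. A broadcasted elementwise multiply followed by a sum over the shared axis is mathematically a matrix product, but PyTensor does not currently rewrite it into a `dot`. As the benchmark below shows, the naive graph is roughly 80× slower than `a @ b`.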
# Reproducer: PyTensor does not recognize a broadcasted multiply followed by a
# sum over the shared axis as a matrix product. The naive expression below
# therefore compiles to an elementwise Mul + Sum over a large 3-D intermediate
# instead of a single Dot22.
import statistics
import timeit

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.matrix("a", shape=(200, 300))
b = pt.matrix("b", shape=(300, 400))

# a[:, :, None] * b broadcasts to (200, 300, 400); summing over axis 1
# contracts the shared dimension, which is exactly a @ b.
dot = (a[:, :, None] * b).sum(1)
fn = pytensor.function([a, b], dot)
pytensor.dprint(fn, print_type=True)
# Sum{axis=1} [id A] <Matrix(float64, shape=(200, 400))> 3
# └─ Mul [id B] <Tensor3(float64, shape=(200, 300, 400))> 2
# ├─ ExpandDims{axis=2} [id C] <Tensor3(float64, shape=(200, 300, 1))> 1
# │ └─ a [id D] <Matrix(float64, shape=(200, 300))>
# └─ ExpandDims{axis=0} [id E] <Tensor3(float64, shape=(1, 300, 400))> 0
# └─ b [id F] <Matrix(float64, shape=(300, 400))>

fn_dot = pytensor.function([a, b], a @ b)
print()
pytensor.dprint(fn_dot, print_type=True)
# Dot22 [id A] <Matrix(float64, shape=(200, 400))> 0
# ├─ a [id B] <Matrix(float64, shape=(200, 300))>
# └─ b [id C] <Matrix(float64, shape=(300, 400))>

# Seeded generator so the comparison and timings are reproducible.
rng = np.random.default_rng(0)
a_test = rng.normal(size=a.type.shape)
b_test = rng.normal(size=b.type.shape)
# Both graphs compute the same values; only the compiled form differs.
np.testing.assert_allclose(fn(a_test, b_test), fn_dot(a_test, b_test))

# Plain-Python equivalent of the IPython `%timeit` calls (same loop counts,
# 7 runs each) so this file runs outside a notebook.
for label, compiled, number in (("naive", fn, 10), ("dot", fn_dot, 1000)):
    # The lambda is consumed within this iteration, so late binding of
    # `compiled` is not an issue.
    runs = timeit.repeat(lambda: compiled(a_test, b_test), number=number, repeat=7)
    per_call_ms = [1e3 * t / number for t in runs]
    print(
        f"{label}: {statistics.mean(per_call_ms):.3f} ms ± "
        f"{statistics.stdev(per_call_ms):.3f} ms per loop "
        f"(mean ± std. dev. of 7 runs, {number:,} loops each)"
    )
# Timings recorded in the issue with %timeit:
#   naive: 70.9 ms ± 1.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#   dot:   861 µs ± 148 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)