Optimize `Sum`s of `MakeVector`s and `Join`s
Not a drastic improvement by any means, but something to keep in mind: rewrite `reduce(pt.concatenate(*tensors)) -> reduce([reduce(tensor) for tensor in tensors])`, ignoring any axis complexities for now.
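The identity behind the proposed rewrite can be sketched with plain NumPy (illustrative variable names only): summing a concatenation is the same as summing each part and adding the partial results.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(1_000)
y = rng.random(2_000)

# Summing the concatenation...
full = np.sum(np.concatenate((x, y)))
# ...equals adding the per-tensor sums, which avoids allocating
# the concatenated intermediate array.
split = np.sum(x) + np.sum(y)

np.testing.assert_allclose(full, split)
```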
```python
import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.vector("x")
y = pt.vector("y")

f1 = pytensor.function([x, y], pt.sum(pt.concatenate((x, y))))
f2 = pytensor.function([x, y], pt.sum((pt.sum(x), pt.sum(y))))
f3 = pytensor.function([x, y], pt.add(pt.sum(x), pt.sum(y)))

pytensor.dprint(f1)
print()
pytensor.dprint(f2)
print()
pytensor.dprint(f3)

x_val = np.random.rand(100_000)
y_val = np.random.rand(200_000)

%timeit f1(x_val, y_val)
%timeit f2(x_val, y_val)
%timeit f3(x_val, y_val)
```
```
Sum{acc_dtype=float64} [id A] '' 1
 |Join [id B] '' 0
   |TensorConstant{0} [id C]
   |x [id D]
   |y [id E]

Sum{acc_dtype=float64} [id A] '' 3
 |MakeVector{dtype='float64'} [id B] '' 2
   |Sum{acc_dtype=float64} [id C] '' 1
   | |x [id D]
   |Sum{acc_dtype=float64} [id E] '' 0
     |y [id F]

Elemwise{Add}[(0, 0)] [id A] '' 2
 |Sum{acc_dtype=float64} [id B] '' 1
 | |x [id C]
 |Sum{acc_dtype=float64} [id D] '' 0
   |y [id E]
```

```
544 µs ± 27.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
270 µs ± 5.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
270 µs ± 8.86 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
Hi! @ricardoV94. I am new to PyMC. I want to work on this. Could you please guide me further?
Updated output:
```
Sum{axes=None} [id A] 1
 └─ Join [id B] 0
    ├─ 0 [id C]
    ├─ x [id D]
    └─ y [id E]

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]

Add [id A] 2
 ├─ Sum{axes=None} [id B] 1
 │  └─ x [id C]
 └─ Sum{axes=None} [id D] 0
    └─ y [id E]
```

```
360 μs ± 8.84 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
293 μs ± 4.04 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
290 μs ± 466 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
f2 and f3 are already optimized (in particular, f2 gets rid of the `MakeVector`). Only f1 is still stuck with the `Join` Op.
Things still missing: there is an optimization for `Sum` of `MakeVector`, introduced in #346, but not for other `CAReduce` Ops. We should extend it.
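The same split applies to other reductions, since each `CAReduce` distributes over concatenation through its own scalar combiner. A quick NumPy check for `prod`, `max`, and `min` (illustrative only, not the rewrite itself):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(100)
y = rng.random(200)
joined = np.concatenate((x, y))

# prod distributes via multiplication of the partial products
np.testing.assert_allclose(np.prod(joined), np.prod(x) * np.prod(y))
# max distributes via the max of the partial maxima
assert np.max(joined) == max(np.max(x), np.max(y))
# min works the same way
assert np.min(joined) == min(np.min(x), np.min(y))
```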
There's also a rewrite for `Sum`/`Prod` of `Join` along axis 0. #888 extends it to any join axis, as long as the reduction is performed along that same axis.
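When the reduced axis is the join axis, the `Join` can be eliminated entirely; a NumPy sketch of the identity behind that rewrite (shapes chosen arbitrarily for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 5))
y = rng.random((4, 7))

# Join along axis 1, then sum along that same axis 1...
joined_sum = np.concatenate((x, y), axis=1).sum(axis=1)
# ...equals summing each operand along axis 1 and adding the results,
# with no concatenated intermediate.
split_sum = x.sum(axis=1) + y.sum(axis=1)

np.testing.assert_allclose(joined_sum, split_sum)
```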
That still leaves out the cases where:
- We concatenate and then sum over axes that include the concatenation axis (done in #1709)
- We concatenate and then sum over axes that exclude the concatenation axis (still needed)
For instance:
```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor3("x", shape=(128, 128, 128))
y = pt.tensor3("y", shape=(128, 128, 128))

joined = pt.join(0, x, y)
out = pt.sum(joined, axis=(1, 2))
fn = pytensor.function([x, y], out)

alt_out = pt.join(
    0,
    pt.sum(x, axis=(1, 2)),
    pt.sum(y, axis=(1, 2)),
)
alt_fn = pytensor.function([x, y], alt_out)

x_test = np.random.normal(size=x.type.shape)
y_test = np.random.normal(size=y.type.shape)

fn.trust_input = True
alt_fn.trust_input = True

np.testing.assert_allclose(fn(x_test, y_test), alt_fn(x_test, y_test))

%timeit fn(x_test, y_test)      # 19.9 ms ± 1.93 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit alt_fn(x_test, y_test)  # 7.73 ms ± 618 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```