pytensor
Avoid explicit broadcasting of indices in Advanced[Inc]Subtensor
Description
Advanced indexing broadcasts indices implicitly, so in the following case there's no reason to allocate an array of ones:
```python
import pytensor
import pytensor.tensor as pt
x = pt.matrix("x")
out = x[pt.arange(x.shape[0]), pt.ones(x.shape[0], dtype=int)]
fn = pytensor.function([x], out)
fn.dprint()
# AdvancedSubtensor [id A] 3
# ├─ x [id B]
# ├─ ARange{dtype='int64'} [id C] 2
# │ ├─ 0 [id D]
# │ ├─ Shape_i{0} [id E] 0
# │ │ └─ x [id B]
# │ └─ 1 [id F]
# └─ Alloc [id G] 1
# ├─ 1 [id H]
# └─ Shape_i{0} [id E] 0
# └─ ···
```
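You can check that the allocated index is redundant with plain NumPy, whose advanced-indexing semantics PyTensor's AdvancedSubtensor is modeled on. The sketch below uses a hypothetical concrete 4×3 matrix in place of the symbolic `x`. It shows that a scalar column index selects the same elements as the explicitly allocated array of ones, because the scalar is broadcast against `rows`.

```python
import numpy as np

# A small concrete matrix standing in for the symbolic `x` above (assumption).
x = np.arange(12).reshape(4, 3)
n = x.shape[0]
rows = np.arange(n)

# Explicitly materialized index, mirroring `pt.ones(x.shape[0], dtype=int)`.
cols_explicit = np.ones(n, dtype=int)
explicit = x[rows, cols_explicit]

# Scalar index: NumPy broadcasts it against `rows`, so no array of ones is needed.
implicit = x[rows, 1]

assert explicit.shape == implicit.shape == (n,)
assert np.array_equal(explicit, implicit)
print(implicit.tolist())  # [1, 4, 7, 10]
```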
We already apply the analogous optimization to the y value of IncSubtensor in local_useless_inc_subtensor_alloc:
https://github.com/pymc-devs/pytensor/blob/11218cf44b0c11e5847d20964327059d57322fc5/pytensor/tensor/rewriting/subtensor.py#L1275-L1279
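On the index side, a rewrite needs one extra check: replacing an Alloc'd integer index with its fill value is only safe when the remaining indices already broadcast to the Alloc's shape. Otherwise the output would shrink. The sketch below illustrates that condition on concrete NumPy arrays using a hypothetical helper, `drop_constant_indices`. It is not PyTensor's API or a proposed implementation. A real rewrite would have to reason about symbolic shapes rather than inspect values.

```python
import numpy as np


def drop_constant_indices(indices):
    """Hypothetical helper (not part of PyTensor): replace constant-filled
    integer index arrays by a scalar when the broadcast index shape is unchanged."""
    indices = [np.asarray(idx) for idx in indices]
    target_shape = np.broadcast_shapes(*(idx.shape for idx in indices))
    simplified = list(indices)
    for i, idx in enumerate(indices):
        # Only integer arrays with at least one element are candidates;
        # boolean indices have different semantics and are left alone.
        if idx.ndim == 0 or idx.size == 0 or not np.issubdtype(idx.dtype, np.integer):
            continue
        flat = idx.ravel()
        if not np.all(flat == flat[0]):
            continue  # not an Alloc-like constant fill
        candidate = simplified[:i] + [flat[0]] + simplified[i + 1:]
        # Keep the scalar only if the other indices still produce the same
        # broadcast shape; otherwise dropping the array would shrink the output.
        if np.broadcast_shapes(*(np.shape(c) for c in candidate)) == target_shape:
            simplified = candidate
    return tuple(simplified)


x = np.arange(12).reshape(4, 3)
rows = np.arange(4)
ones = np.ones(4, dtype=int)

# Same situation as in the issue: the ones can be replaced by the scalar 1.
idx = drop_constant_indices([rows, ones])
assert np.ndim(idx[1]) == 0
assert np.array_equal(x[idx], x[rows, ones])

# Here the constant index determines the output shape, so it must stay materialized.
wide_ones = np.ones((2, 4), dtype=int)
idx2 = drop_constant_indices([rows, wide_ones])
assert idx2[1].shape == (2, 4)
```

The greedy loop checks each index against the already-simplified list, so when several indices are constant, only those whose removal leaves the broadcast shape unchanged are dropped.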