pytensor
Reconsider use of MakeSlice and NoneConst as inputs to AdvancedIndexing
Description
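Because AdvancedSubtensor and AdvancedSetSubtensor receive MakeSlice and NoneConst nodes as inputs, we cannot use Blockwise / vectorize_graph trivially, even in cases where vectorization would be valid: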
```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
x = pt.matrix("x")
s = pt.scalar("s", dtype="int")
out = pt.set_subtensor(x[s:, [0, 0]], 0)
pytensor.dprint(out)
# AdvancedSetSubtensor [id A]
# ├─ x [id B]
# ├─ 0 [id C]
# ├─ MakeSlice [id D]
# │ ├─ s [id E]
# │ ├─ NoneConst{None} [id F]
# │ └─ NoneConst{None} [id F]
# └─ [0 0] [id G]
z = pt.vector("z", dtype="int")
vec_out = vectorize_graph(out, replace={s: z})  # Raises an error
```
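This case is well defined: each batch entry just uses a different start index. A minimal NumPy sketch of what `vec_out` would be expected to compute, assuming `x` is shared across the batch and `z` holds one start per entry. `batched_advanced_set_subtensor` is a hypothetical helper, not part of PyTensor.

```python
import numpy as np


def batched_advanced_set_subtensor(x, z, value=0):
    """Hypothetical reference for set_subtensor(x[s:, [0, 0]], value) vectorized over s.

    x is a matrix shared by every batch entry; z holds one start index per entry.
    Returns an array of shape (len(z), *x.shape).
    """
    out = np.empty((len(z),) + x.shape, dtype=x.dtype)
    for b, s in enumerate(z):
        core = x.copy()
        core[s:, [0, 0]] = value  # same core operation as the unbatched graph
        out[b] = core
    return out


x = np.arange(6.0).reshape(3, 2)
print(batched_advanced_set_subtensor(x, np.array([0, 2])))
```

Each batch entry only differs in a plain integer input, which is exactly the situation Blockwise is designed for.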
We can do it for Subtensor because it takes only numerical inputs and stores the information about what those inputs represent (here, that the single input is the start of a slice, as shown by `SetSubtensor{start:}`) as properties of the Op:
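One possible direction for the advanced case, sketched under assumptions: split the index into a static template (which entries are slices, which slice bounds are `None`) that could be stored on the Op, plus a flat list of purely numerical inputs. This is a minimal pure-Python illustration, not PyTensor's implementation; `Dyn`, `split_index` and `rebuild_index` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dyn:
    """Hypothetical placeholder: 'this position is filled by a runtime input'."""


def split_index(idx):
    """Split an index tuple into (static template, list of numerical inputs).

    None slice bounds stay in the template as literal None, so they never
    become graph inputs. Every other bound, and every non-slice entry
    (an int or an integer list), is replaced by Dyn() and its value is
    appended to the inputs list.
    """
    template, inputs = [], []
    for entry in idx:
        if isinstance(entry, slice):
            bounds = []
            for bound in (entry.start, entry.stop, entry.step):
                if bound is None:
                    bounds.append(None)
                else:
                    bounds.append(Dyn())
                    inputs.append(bound)
            template.append(slice(*bounds))
        else:
            template.append(Dyn())
            inputs.append(entry)
    return tuple(template), inputs


def rebuild_index(template, inputs):
    """Inverse of split_index: refill Dyn() placeholders from inputs, in order."""
    values = iter(inputs)

    def fill(item):
        return next(values) if isinstance(item, Dyn) else item

    rebuilt = []
    for entry in template:
        if isinstance(entry, slice):
            rebuilt.append(slice(fill(entry.start), fill(entry.stop), fill(entry.step)))
        else:
            rebuilt.append(fill(entry))
    return tuple(rebuilt)


# The index from the first example, x[s:, [0, 0]], with s = 2
template, inputs = split_index((slice(2, None), [0, 0]))
print(template)  # (slice(Dyn(), None, None), Dyn())
print(inputs)    # [2, [0, 0]]
```

Since the `None` bounds live in the template, no NoneConst input is needed, and vectorizing over `s` would only touch the numerical inputs while the template stays fixed.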
```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph
x = pt.vector("x")
s = pt.scalar("s", dtype="int")
out = pt.set_subtensor(x[s:], 0)
pytensor.dprint(out)
# SetSubtensor{start:} [id A]
# ├─ x [id B]
# ├─ 0 [id C]
# └─ ScalarFromTensor [id D]
#    └─ s [id E]
z = pt.vector("z", dtype="int")
vec_out = vectorize_graph(out, replace={s: z})
pytensor.dprint(vec_out)
# Blockwise{SetSubtensor{start:}, (i00),(),()->(o00)} [id A]
# ├─ ExpandDims{axis=0} [id B]
# │ └─ x [id C]
# ├─ ExpandDims{axis=0} [id D]
# │ └─ 0 [id E]
# └─ Blockwise{ScalarFromTensor, ()->()} [id F]
#    └─ z [id G]
```
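Roughly, the Blockwise node above broadcasts the leading (batch) dimensions of all inputs and then applies the core `SetSubtensor{start:}` once per batch entry, following the signature `(i00),(),()->(o00)`. A minimal NumPy sketch of that idea, not PyTensor's actual Blockwise implementation; `blockwise`, `core_ndims` and `set_subtensor_start` are hypothetical names.

```python
import numpy as np


def blockwise(core_fn, core_ndims, *inputs):
    """Apply core_fn over the broadcast leading (batch) dims of inputs.

    core_ndims[i] is how many trailing dims of inputs[i] belong to the core
    signature; for (i00),(),()->(o00) that is (1, 0, 0).
    Assumes at least one batch entry.
    """
    arrays = [np.asarray(a) for a in inputs]
    batch_shape = np.broadcast_shapes(
        *(a.shape[: a.ndim - nd] for a, nd in zip(arrays, core_ndims))
    )
    # Broadcast every input's batch dims to the common batch shape.
    arrays = [
        np.broadcast_to(a, batch_shape + a.shape[a.ndim - nd :])
        for a, nd in zip(arrays, core_ndims)
    ]
    # Call the core function once per batch position.
    results = [core_fn(*(a[idx] for a in arrays)) for idx in np.ndindex(*batch_shape)]
    return np.stack(results).reshape(batch_shape + np.shape(results[0]))


def set_subtensor_start(x, value, start):
    """Core op: the unbatched equivalent of set_subtensor(x[start:], value)."""
    out = np.array(x, copy=True)
    out[start:] = value
    return out


# Mirror the graph: x and 0 get a length-1 batch dim (ExpandDims{axis=0}), z is batched.
x = np.arange(4.0)
print(blockwise(set_subtensor_start, (1, 0, 0), x[None, :], np.array([0]), np.array([1, 3])))
```

This only works because every input is a plain array with a known core rank; a MakeSlice or NoneConst input has no such batch dimension to broadcast over.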