pytensor
Reshape should take each shape dimension as a separate input
Description
Reshape has only two inputs: x and a vector with the output shape. This is cumbersome because we often want to analyze the individual dimensions, for example to rewrite a Reshape as expand_dims or to remove a useless Reshape. The Reshape Op also has to be parametrized with the length of the output shape, because historically we didn't have static shapes and couldn't always infer how many entries the shape vector had.
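To illustrate why per-dimension entries help, here is a minimal sketch of the "is this Reshape just expand_dims?" check. It assumes a hypothetical toy representation, not PyTensor's API: each output-shape entry is either an int constant or a `("shape", i)` tuple standing for `x.shape[i]`. With one entry per dimension, the check is a single left-to-right scan. With today's single shape-vector input, a rewrite would first have to recover these entries.

```python
from typing import List, Optional, Sequence, Tuple, Union

# Hypothetical toy IR (NOT PyTensor's API): an entry is either an int
# constant or ("shape", i), meaning x.shape[i].
Dim = Union[int, Tuple[str, int]]


def expand_dims_axes(x_ndim: int, out_shape: Sequence[Dim]) -> Optional[List[int]]:
    """Return the output positions where Reshape only inserts length-1 axes,
    or None if this Reshape is not equivalent to expand_dims.
    An empty list means the Reshape is a no-op (a "useless" Reshape)."""
    new_axes: List[int] = []
    next_in = 0  # the next input dimension we expect to see, in order
    for pos, dim in enumerate(out_shape):
        if dim == ("shape", next_in):
            next_in += 1  # input dimension passed through unchanged
        elif dim == 1:
            new_axes.append(pos)  # a new broadcastable axis
        else:
            return None  # anything else is a genuine reshape
    # Every input dimension must appear exactly once, in order.
    return new_axes if next_in == x_ndim else None
```

For example, `expand_dims_axes(2, [("shape", 0), 1, ("shape", 1)])` yields `[1]`, while `[("shape", 1), ("shape", 0)]` yields `None` because the dimensions are not kept in order.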
Most of the time Reshape is used to merge dimensions, so we end up with shape vectors like [x.shape[0], ..., x.shape[n] * x.shape[m], ..., x.shape[-1]], wrapped in a MakeVector. This makes Reshape rewrites harder, because they have to handle both the case where the entries are joined in a MakeVector and the case where they have been constant-folded into a single tensor.
https://github.com/pymc-devs/pytensor/blob/bf73f8a06be2adf1d30e4f59e30c2dfa49c5204e/pytensor/tensor/rewriting/shape.py#L921-L926
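The extra burden can be sketched with toy node types. `ShapeOf`, `Mul`, `MakeVector` and `Constant` below are hypothetical stand-ins, not PyTensor's classes (the real MakeVector is an Op). `shape_entries` is the normalization step a rewrite needs today, because the shape arrives as one vector in several possible forms. If Reshape took one input per dimension, the node's inputs would already be the entries, and only the per-entry analysis (`merged_axes`) would remain.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


# Hypothetical toy node types, NOT PyTensor classes.
@dataclass(frozen=True)
class ShapeOf:  # stands for x.shape[axis]
    axis: int


@dataclass(frozen=True)
class Mul:  # stands for left * right
    left: object
    right: object


@dataclass(frozen=True)
class MakeVector:  # a symbolic vector built from scalar entries
    entries: Tuple[object, ...]


@dataclass(frozen=True)
class Constant:  # a shape vector that was constant-folded
    value: Tuple[int, ...]


def shape_entries(shape_input: object) -> Optional[list]:
    """Recover per-dimension entries from a single shape vector,
    which may arrive in different forms."""
    if isinstance(shape_input, MakeVector):
        return list(shape_input.entries)
    if isinstance(shape_input, Constant):
        return list(shape_input.value)
    return None  # opaque vector: individual entries are not accessible


def merged_axes(entry: object) -> Optional[List[int]]:
    """Return the input axes multiplied together in one output entry,
    e.g. x.shape[1] * x.shape[2] -> [1, 2]; None if not of that form."""
    if isinstance(entry, ShapeOf):
        return [entry.axis]
    if isinstance(entry, Mul):
        left, right = merged_axes(entry.left), merged_axes(entry.right)
        if left is not None and right is not None:
            return left + right
    return None
```

For `MakeVector((ShapeOf(0), Mul(ShapeOf(1), ShapeOf(2))))`, mapping `merged_axes` over `shape_entries(...)` gives `[[0], [1, 2]]`. Once the same shape is constant-folded to `Constant((2, 6))`, the entries become plain ints and the merge structure can no longer be read off.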
SpecifyShape already works with a variable number of inputs, and we haven't had any trouble with it.