pytensor
DimShuffle should be happy to drop dims if they have length 1 at runtime
Description
Right now we need to add specify_shape if we want to squeeze a dimension that the user (but not PyTensor) knows to be length 1. This complicates the graph slightly, and I don't see a reason for it. The only thing DimShuffle needs to know is the number of dimensions of the input, which is never ambiguous. Any input dimension missing from the pattern is then simply a drop.
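To make the proposed semantics concrete, here is a minimal NumPy sketch of what a DimShuffle-like op could do at runtime. It assumes the pattern follows DimShuffle's convention (input axis indices, plus "x" for a new length-1 axis). The function name dimshuffle_runtime is hypothetical and is not PyTensor's implementation. The only static information it uses is x.ndim; the length-1 requirement for dropped axes is checked against the actual array.

```python
import numpy as np


def dimshuffle_runtime(x, pattern):
    """Hypothetical runtime semantics for a DimShuffle-like op (not PyTensor code).

    pattern: tuple of input axis indices and/or "x" (insert a length-1 axis).
    Input axes that do not appear in the pattern are dropped, which is only
    valid if they have length 1 in the actual runtime array.
    """
    kept = [p for p in pattern if p != "x"]
    # Only x.ndim is needed to validate the pattern statically
    if len(set(kept)) != len(kept) or any(not 0 <= i < x.ndim for i in kept):
        raise ValueError(f"Invalid pattern {pattern} for a {x.ndim}-d input")
    dropped = [i for i in range(x.ndim) if i not in kept]
    # Runtime check: a dropped axis must really be length 1
    for i in dropped:
        if x.shape[i] != 1:
            raise ValueError(
                f"Cannot drop axis {i}: runtime length is {x.shape[i]}, expected 1"
            )
    # Put kept axes first in the requested order; dropped (length-1) axes go last
    y = np.transpose(x, kept + dropped)
    # "x" entries become length-1 axes; dropped axes vanish from the shape
    out_shape = [1 if p == "x" else x.shape[p] for p in pattern]
    return y.reshape(out_shape)
```

For example, a (3, 1) array with pattern (0,) yields shape (3,), while a (3, 2) array with the same pattern raises ValueError instead of producing a result.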
https://github.com/pymc-devs/pytensor/blob/ee4d4f71f932604c7e398c9a2bb6c6cef0d6e91f/pytensor/tensor/extra_ops.py#L602-L607
We should check that nothing in the implementation fails if an axis that was meant to be dropped turns out not to have length 1 at runtime. If nothing fails, we can simplify DimShuffle and get rid of the useless SpecifyShape.
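The check matters because different ways of dropping an axis behave differently when that axis is not actually length 1. The sketch below compares three hypothetical helpers (drop_by_indexing, drop_by_reshape, drop_by_squeeze), built on plain NumPy and not taken from PyTensor: indexing silently keeps the first slice, while reshape and np.squeeze raise ValueError.

```python
import numpy as np


def drop_by_indexing(x, axis):
    # Takes the first slice along `axis`; never checks the axis length,
    # so a length-2 axis silently loses data
    return np.take(x, 0, axis=axis)


def drop_by_reshape(x, axis):
    # Removes `axis` from the target shape; NumPy raises ValueError
    # if the total size no longer matches (i.e. the axis was not length 1)
    new_shape = x.shape[:axis] + x.shape[axis + 1:]
    return x.reshape(new_shape)


def drop_by_squeeze(x, axis):
    # np.squeeze raises ValueError if the selected axis is not length 1
    return np.squeeze(x, axis=axis)


x = np.arange(6).reshape(3, 2)  # axis 1 has length 2, so it must not be dropped
print(drop_by_indexing(x, 1))   # silently returns the first column
for drop in (drop_by_reshape, drop_by_squeeze):
    try:
        drop(x, 1)
    except ValueError as err:
        print(f"{drop.__name__}: {err}")
```

A test for a simplified DimShuffle would likely want to assert the raising behaviour, so that a graph without SpecifyShape cannot quietly discard data.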