DimShuffle does not need input broadcastable info
Description
Make DimShuffle less strict about its input's static broadcastable type when dropping dims, since no runtime behavior depends on that information being specified.
There is no reason the input's broadcastable type has to be known in advance, since no behavior of the Op changes with that information. Dropping a dimension only works when its runtime length is 1, but if it isn't, the operation will naturally fail at runtime.
Alternatively, or in addition, we could add a `specify_broadcastable` during `make_node`. Beyond this potential extension, I see no reason for DimShuffle to be so strict.
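To illustrate why the static type adds nothing, here is a minimal NumPy sketch of DimShuffle-style semantics. It is not PyTensor's implementation: the `dimshuffle` function below is hypothetical, assumes non-negative, non-repeated axis indices, and uses `"x"` for new length-1 axes (mirroring PyTensor's `new_order` convention). Axes left out of `new_order` are dropped, and the only check is on the runtime shape.

```python
import numpy as np


def dimshuffle(x: np.ndarray, new_order: tuple) -> np.ndarray:
    """Hypothetical NumPy sketch of DimShuffle-style semantics (not PyTensor's code).

    `new_order` lists the input axes to keep, in their output order; the string
    "x" inserts a new length-1 axis. Input axes missing from `new_order` are dropped.
    Assumes non-negative, non-repeated axis indices.
    """
    kept = [d for d in new_order if d != "x"]
    dropped = [d for d in range(x.ndim) if d not in kept]

    # The only requirement for dropping an axis is checked on the runtime shape;
    # no static "broadcastable" flag is consulted.
    for d in dropped:
        if x.shape[d] != 1:
            raise ValueError(f"cannot drop axis {d}: runtime length is {x.shape[d]}, not 1")

    # Move kept axes to the front (in output order), then discard the trailing
    # length-1 axes that are being dropped.
    y = np.transpose(x, kept + dropped)
    y = y.reshape(y.shape[: len(kept)])

    # Insert the new length-1 axes; processing left to right keeps positions valid.
    for i, d in enumerate(new_order):
        if d == "x":
            y = np.expand_dims(y, i)
    return y
```

For example, `dimshuffle(np.zeros((1, 3)), (1,))` returns an array of shape `(3,)`, while the same call on an array of shape `(2, 3)` raises `ValueError`, without any static type information involved.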
This fixes the bug reported in #969
Related Issue
- [x] Closes #914
- [x] Closes #969
Codecov Report
Attention: Patch coverage is 88.88889% with 5 lines in your changes missing coverage. Please review.
Project coverage is 81.73%. Comparing base (1c2bc8f) to head (2425c9a). Report is 108 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/elemwise.py | 82.75% | 2 Missing and 3 partials :warning: |
Additional details and impacted files
Coverage Diff

| | main | #979 | +/- |
|---|---|---|---|
| Coverage | 81.73% | 81.73% | |
| Files | 183 | 183 | |
| Lines | 47734 | 47721 | -13 |
| Branches | 11611 | 11601 | -10 |
| Hits | 39016 | 39007 | -9 |
| Misses | 6523 | 6521 | -2 |
| Partials | 2195 | 2193 | -2 |
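The coverage figures above are internally consistent: Hits + Misses + Partials equals Lines on both sides. The small check below reproduces the reported 81.73% for both base and head, assuming (as Codecov's defaults are understood to do) that partially covered lines count as uncovered and that the percentage is rounded down to two decimals. Ordinary rounding would give 81.74% in both cases. The `coverage_pct` helper is illustrative only, not part of any Codecov API.

```python
from decimal import ROUND_DOWN, Decimal


def coverage_pct(hits: int, misses: int, partials: int) -> Decimal:
    """Line coverage in percent: hits / (hits + misses + partials).

    Assumption: partial lines count as uncovered and the result is truncated
    (rounded down) to two decimal places.
    """
    lines = hits + misses + partials
    pct = Decimal(hits) * 100 / Decimal(lines)
    return pct.quantize(Decimal("0.01"), rounding=ROUND_DOWN)


# Figures copied from the coverage diff (main vs. #979).
base = coverage_pct(hits=39016, misses=6523, partials=2195)
head = coverage_pct(hits=39007, misses=6521, partials=2193)
print(base, head)  # 81.73 81.73
```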
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/sparse/sandbox/sp.py | 74.61% <100.00%> (-0.20%) | :arrow_down: |
| pytensor/tensor/basic.py | 91.54% <100.00%> (ø) | |
| pytensor/tensor/extra_ops.py | 87.73% <100.00%> (-0.04%) | :arrow_down: |
| pytensor/tensor/fft.py | 82.29% <100.00%> (ø) | |
| pytensor/tensor/inplace.py | 100.00% <100.00%> (ø) | |
| pytensor/tensor/math.py | 91.27% <100.00%> (-0.01%) | :arrow_down: |
| pytensor/tensor/random/rewriting/jax.py | 97.08% <ø> (ø) | |
| pytensor/tensor/rewriting/basic.py | 94.15% <100.00%> (ø) | |
| pytensor/tensor/rewriting/elemwise.py | 91.14% <100.00%> (+0.29%) | :arrow_up: |
| pytensor/tensor/rewriting/jax.py | 82.81% <ø> (ø) | |
| ... and 4 more | | |