
Dimshuffle does not need input broadcastable info

Open ricardoV94 opened this issue 1 year ago • 1 comments

Description

Make DimShuffle less concerned about its input's static broadcastable type when dropping dims, since no behavior of the Op depends on that type being specified ahead of time.

There is no reason the input's broadcastable type has to be known in advance, since no behavior of the Op changes with that information. It is true that dropping a dimension only works when its runtime length is 1, but the Op will naturally fail at runtime if we try to drop a dimension of any other length.

Alternatively, or in addition, we could apply specify_broadcastable to the input during make_node. Beyond this potential extension, I see no reason why DimShuffle should be so strict.
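
To illustrate the runtime-check argument, here is a minimal pure-Python sketch. It is not PyTensor's implementation: `drop_dims` is a hypothetical helper that works on plain shape tuples, and it only shows that a length check on the runtime shape is enough to reject invalid drops without any static broadcastable information.

```python
def drop_dims(runtime_shape, axes_to_drop):
    """Return the shape left after dropping `axes_to_drop`.

    Hypothetical helper for illustration only, not PyTensor code.
    No static broadcastable information is consulted; only the
    runtime shape is checked.
    """
    for axis in axes_to_drop:
        if runtime_shape[axis] != 1:
            raise ValueError(
                f"Cannot drop axis {axis} with length {runtime_shape[axis]}; expected 1"
            )
    dropped = set(axes_to_drop)
    return tuple(
        length for axis, length in enumerate(runtime_shape) if axis not in dropped
    )


# A length-1 axis that is only known at runtime: dropping succeeds.
print(drop_dims((3, 1, 4), axes_to_drop=[1]))

# Any other length fails at runtime, with no static type information needed.
try:
    drop_dims((3, 2, 4), axes_to_drop=[1])
except ValueError as err:
    print(f"runtime error: {err}")
```

The check happens where the data is available, which is the point made above: knowing the broadcastable pattern in advance adds strictness but does not change what the operation does.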

This fixes the bug reported in #969

Related Issue

  • [x] Closes #914
  • [x] Closes #969

ricardoV94 avatar Aug 22 '24 12:08 ricardoV94

Codecov Report

Attention: Patch coverage is 88.88889% with 5 lines in your changes missing coverage. Please review.

Project coverage is 81.73%. Comparing base (1c2bc8f) to head (2425c9a). Report is 108 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/elemwise.py | 82.75% | 2 Missing and 3 partials :warning: |

Additional details and impacted files


@@           Coverage Diff           @@
##             main     #979   +/-   ##
=======================================
  Coverage   81.73%   81.73%           
=======================================
  Files         183      183           
  Lines       47734    47721   -13     
  Branches    11611    11601   -10     
=======================================
- Hits        39016    39007    -9     
+ Misses       6523     6521    -2     
+ Partials     2195     2193    -2     

| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/sparse/sandbox/sp.py | 74.61% <100.00%> (-0.20%) | :arrow_down: |
| pytensor/tensor/basic.py | 91.54% <100.00%> (ø) | |
| pytensor/tensor/extra_ops.py | 87.73% <100.00%> (-0.04%) | :arrow_down: |
| pytensor/tensor/fft.py | 82.29% <100.00%> (ø) | |
| pytensor/tensor/inplace.py | 100.00% <100.00%> (ø) | |
| pytensor/tensor/math.py | 91.27% <100.00%> (-0.01%) | :arrow_down: |
| pytensor/tensor/random/rewriting/jax.py | 97.08% <ø> (ø) | |
| pytensor/tensor/rewriting/basic.py | 94.15% <100.00%> (ø) | |
| pytensor/tensor/rewriting/elemwise.py | 91.14% <100.00%> (+0.29%) | :arrow_up: |
| pytensor/tensor/rewriting/jax.py | 82.81% <ø> (ø) | |

... and 4 more

codecov[bot] avatar Aug 24 '24 10:08 codecov[bot]