
Implement `pt.pad`

Open jessegrabowski opened this issue 1 year ago • 2 comments

Description

Implement pt.pad, following the np.pad API with feature parity.

Very preliminary draft; I'm uploading it in this state so I can ask @ricardoV94 to look at the `_linear_ramp_pad` function and tell me if I'm missing something obvious related to shapes. It should follow `numpy.lib.arraypad._get_linear_ramps`. Also, the reflection pad uses a scan; I'm curious whether we can avoid that somehow, or whether we think it will be no big deal (probably the latter).
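For reference, here is a numpy sketch of the behavior `_linear_ramp_pad` is targeting. The `left_ramp` line is illustrative of what `numpy.lib.arraypad._get_linear_ramps` does per side, not the actual helper:

```python
import numpy as np

# Target behavior: np.pad with mode="linear_ramp" ramps from end_values at
# the outer edge up to the array's edge value at the boundary.
x = np.array([1.0, 2.0, 3.0])
padded = np.pad(x, (2, 3), mode="linear_ramp", end_values=0)
# -> [0., 0.5, 1., 2., 3., 2., 1., 0.]

# One side's ramp can be built with linspace, excluding the boundary value
# itself (which is already the array's own edge element):
left_ramp = np.linspace(0, x[0], num=2, endpoint=False)
# -> [0., 0.5]
```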

Also, I'm not sure where to put this. I put it in `tensor/basic`, but it might be better in `tensor/extra_ops`?

Related Issue

  • [ ] Closes #743
  • [ ] Related to #548

Checklist

Type of change

  • [ ] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

jessegrabowski avatar May 04 '24 16:05 jessegrabowski

What about a padding.py file?

ricardoV94 avatar May 04 '24 17:05 ricardoV94

Sure, I'll make a new file. It's just not my default. I agree it doesn't belong in basic.

jessegrabowski avatar May 05 '24 01:05 jessegrabowski

Not quite 1:1 with numpy's features, but close. I would need more time to understand the more exotic padding schemes.

Still needs jax/numba overloads, but these should be very trivial.
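A rough idea of what the JAX overload could look like. This is a hypothetical sketch only: the real dispatch would use `@jax_funcify.register(Pad)` from `pytensor.link.jax.dispatch`, and `Pad` / `op.pad_mode` are assumed names for the OpFromGraph that hangs on to the keyword arguments, not the merged API:

```python
def jax_funcify_Pad(op, **kwargs):
    # `pad_mode` is recovered from the (hypothetical) OFG attribute,
    # not from the symbolic inputs.
    pad_mode = op.pad_mode

    def pad(x, pad_width):
        # Lazy import keeps the sketch importable without jax installed.
        import jax.numpy as jnp

        return jnp.pad(x, pad_width, mode=pad_mode)

    return pad
```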

jessegrabowski avatar May 11 '24 15:05 jessegrabowski

Codecov Report

Attention: Patch coverage is 94.30380% with 18 lines in your changes missing coverage. Please review.

Project coverage is 81.48%. Comparing base (c6d85d1) to head (bbeb300).

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #748      +/-   ##
==========================================
+ Coverage   81.38%   81.48%   +0.09%     
==========================================
  Files         172      174       +2     
  Lines       46868    47166     +298     
  Branches    11423    11471      +48     
==========================================
+ Hits        38145    38434     +289     
- Misses       6542     6548       +6     
- Partials     2181     2184       +3     
Files Coverage Δ
pytensor/link/jax/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/jax/dispatch/pad.py 100.00% <100.00%> (ø)
pytensor/tensor/subtensor.py 89.31% <100.00%> (+0.11%) :arrow_up:
pytensor/tensor/pad.py 97.14% <97.14%> (ø)
pytensor/tensor/extra_ops.py 87.64% <80.95%> (-0.98%) ↓

... and 4 files with indirect coverage changes

codecov[bot] avatar May 11 '24 15:05 codecov[bot]

Draft of the JAX overload. Need your input on the `Pad` OpFromGraph; I needed some way to hang on to the keyword arguments.

It seems like there might be a difference between how JAX and numpy handle `mode="mean"` padding, because tests pass against numpy but not against JAX. I'll investigate more carefully, but it might be a JAX bug (I doubt this padding mode is ever used).
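For reference, numpy's semantics, which JAX should presumably match: by default the pad values are the mean over the full axis (`stat_length` can restrict the window):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
padded = np.pad(x, 2, mode="mean")
# mean of [1, 2, 3] is 2, so both sides are padded with 2:
# -> [2., 2., 1., 2., 3., 2., 2.]
```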

I also think my loopy pads (symmetric, wrap) need to be redone, because they are failing a new test that arbitrarily pads every dimension of an nd input differently. So all that probably needs a re-design from the ground up.

jessegrabowski avatar May 18 '24 15:05 jessegrabowski

> I also think my loopy pads (symmetric, wrap) need to be redone, because they are failing a new test that arbitrarily pads every dimension of an nd input differently. So all that probably needs a re-design from the ground up.

You may need an operation per dimension
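A numpy sketch of the "one operation per dimension" idea for wrap padding; the function names are illustrative, and it assumes each pad width is at most the axis length (larger widths would need the slices repeated):

```python
import numpy as np

def _wrap_one_axis(x, axis, before, after):
    # Take the trailing/leading slices along a single axis and concatenate.
    idx = [slice(None)] * x.ndim
    idx[axis] = slice(x.shape[axis] - before, None)
    left = x[tuple(idx)]
    idx[axis] = slice(None, after)
    right = x[tuple(idx)]
    return np.concatenate([left, x, right], axis=axis)

def wrap_pad(x, pad_width):
    # Pad one axis at a time; each axis can have different (before, after).
    for axis, (before, after) in enumerate(pad_width):
        x = _wrap_one_axis(x, axis, before, after)
    return x
```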

ricardoV94 avatar May 20 '24 07:05 ricardoV94

Regarding JAX, do you need to implement a specific dispatch? For instance, for einsum I don't think we'll need one, because the OFG expression will be as good as what they do internally (since we copied it from them).

ricardoV94 avatar May 20 '24 07:05 ricardoV94

No idea on the JAX dispatch. I just assumed I should.

jessegrabowski avatar May 20 '24 07:05 jessegrabowski


Closes #923

williambdean avatar Jul 12 '24 07:07 williambdean

Some static shape information is getting lost between the scans, so the JAX test is failing now. Need to inspect the graphs and figure out what is going on :(

jessegrabowski avatar Jul 13 '24 12:07 jessegrabowski

Not shocked. If JAX is fussy we may keep it as an OFG for that backend and inline it in the remaining backends.

ricardoV94 avatar Jul 13 '24 12:07 ricardoV94

Don't understand why the doctest for pad is failing

jessegrabowski avatar Jul 13 '24 15:07 jessegrabowski

> Don't understand why the doctest for pad is failing

It says there is output that was not expected. If you have a print somewhere, its output always needs to be matched afterward in the doctest.
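A tiny self-contained illustration of the failure mode: every line a doctest prints must be matched by an expected-output line, or `pytest --doctest-modules` reports unexpected output:

```python
def pad_example():
    """Toy doctest: the printed list is matched by the line below it.

    Dropping or mismatching that expected line is exactly the kind of
    failure doctest reports as unexpected output.

    >>> print([0, 1, 2, 3])
    [0, 1, 2, 3]
    """
```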

ricardoV94 avatar Jul 13 '24 16:07 ricardoV94

You can also run doctest locally btw

ricardoV94 avatar Jul 13 '24 16:07 ricardoV94

Something like `pytest --doctest-modules pytensor/tensor/pad.py --verbose`

ricardoV94 avatar Jul 13 '24 16:07 ricardoV94

We should open a follow-up issue for performance. With the reshape and concatenation, we're doing a lot of copies. We should see how much better it would be to have scans with `set_subtensor` like you tried halfway.
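A numpy sketch of the idea for the constant-pad case; the names are illustrative, and in pytensor the assignment would be a `pt.set_subtensor` into an allocated output rather than building the result from concatenated copies:

```python
import numpy as np

def pad_constant_preallocated(x, pad_width, value=0.0):
    # Allocate the padded output once, filled with the pad value.
    out_shape = tuple(s + b + a for s, (b, a) in zip(x.shape, pad_width))
    out = np.full(out_shape, value, dtype=x.dtype)
    # Write the original array into the interior in a single assignment
    # (the pytensor analogue of pt.set_subtensor(out[interior], x)).
    interior = tuple(slice(b, b + s) for s, (b, _) in zip(x.shape, pad_width))
    out[interior] = x
    return out
```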

ricardoV94 avatar Jul 13 '24 18:07 ricardoV94

I kind of just want to skip the segfault test and come back to it later. I am trying to debug, but I'm not really sure what's going on. It runs fine when the `NUMBA_DISABLE_JIT` flag is set. My suspicion is an out-of-range index, but I can't reproduce the error in a notebook. I'd be fine just removing the numba tests altogether and considering it unsupported, but every other mode passes. So idk. I just want to get this over the finish line.

jessegrabowski avatar Jul 14 '24 02:07 jessegrabowski