pytensor
Implement `pt.pad`
Description
Implement pt.pad, following the np.pad API with feature parity.
Very preliminary draft; I'm uploading it in this state so I can ask @ricardoV94 to look at the `_linear_ramp_pad` function and tell me if I'm missing something obvious related to shapes. It should follow `numpy.lib.arraypad._get_linear_ramps`. Also, the reflection pad uses a scan; curious if we can avoid that somehow, or if we think it will be no big deal (probably the latter).
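For reference, a minimal numpy sketch of the target semantics for linear-ramp padding (this is an illustration of what `np.pad(..., mode="linear_ramp")` computes, not the PR's `_linear_ramp_pad`; the function name here is made up): each padded region ramps linearly from an `end_value` at the outer edge up to the array's boundary value.

```python
import numpy as np

def linear_ramp_pad_1d(x, pad_width, end_values=(0.0, 0.0)):
    """Hypothetical 1-d illustration of linear-ramp padding semantics."""
    # Left ramp runs from end_value toward the first element, excluding it.
    left = np.linspace(end_values[0], x[0], pad_width[0], endpoint=False)
    # Right ramp runs from the last element down to end_value, excluding the edge.
    right = np.linspace(x[-1], end_values[1], pad_width[1] + 1)[1:]
    return np.concatenate([left, x, right])

x = np.array([3.0, 5.0])
print(linear_ramp_pad_1d(x, (2, 2)))                  # matches np.pad below
print(np.pad(x, (2, 2), mode="linear_ramp"))
```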
Also I'm not sure where to put this. I put it in tensor/basic but it might be better in tensor/extra_ops?
Related Issue
- [ ] Closes #743
- [ ] Related to #548
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
What about a padding.py file?
Sure, I'll make a new file. It's just not my default. I agree it doesn't belong in basic.
Not quite 1:1 on numpy features, but close. I would need more time to understand the more exotic padding schemes.
Still needs jax/numba overloads, but these should be very trivial.
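To illustrate the shape of a backend overload (a toy `functools.singledispatch` sketch, not pytensor's actual `jax_funcify` machinery; the `Pad` class here is a hypothetical stand-in for the Pad `OpFromGraph`):

```python
from functools import singledispatch

# Toy stand-in for per-backend dispatch: the backend looks up a compiled
# implementation by the Op's type and returns a callable/description.
@singledispatch
def funcify(op):
    raise NotImplementedError(f"No overload for {type(op).__name__}")

class Pad:  # hypothetical stand-in for the Pad OpFromGraph
    def __init__(self, mode):
        self.mode = mode

@funcify.register(Pad)
def _pad_funcify(op):
    # A real overload would return a function implementing the padding;
    # here we just show that the mode keyword survives on the Op.
    return f"compiled pad with mode={op.mode}"

print(funcify(Pad("reflect")))  # → compiled pad with mode=reflect
```

Keeping the keyword arguments on the Op instance is what lets the dispatch function recover them at compile time.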
Codecov Report
Attention: Patch coverage is 94.30380% with 18 lines in your changes missing coverage. Please review.
Project coverage is 81.48%. Comparing base (c6d85d1) to head (bbeb300).
Additional details and impacted files
@@ Coverage Diff @@
## main #748 +/- ##
==========================================
+ Coverage 81.38% 81.48% +0.09%
==========================================
Files 172 174 +2
Lines 46868 47166 +298
Branches 11423 11471 +48
==========================================
+ Hits 38145 38434 +289
- Misses 6542 6548 +6
- Partials 2181 2184 +3
| Files | Coverage Δ | |
|---|---|---|
| pytensor/link/jax/dispatch/__init__.py | 100.00% <100.00%> (ø) | |
| pytensor/link/jax/dispatch/pad.py | 100.00% <100.00%> (ø) | |
| pytensor/tensor/subtensor.py | 89.31% <100.00%> (+0.11%) | :arrow_up: |
| pytensor/tensor/pad.py | 97.14% <97.14%> (ø) | |
| pytensor/tensor/extra_ops.py | 87.64% <80.95%> (-0.98%) | :arrow_down: |
Draft of the JAX overload. Need your input on the Pad OpFromGraph. I needed some way to hang on to the keyword arguments.
It seems like there might be a difference between how JAX and numpy handle mode=mean padding, because tests pass against numpy but not against JAX. I'll investigate more carefully, but it might be a JAX bug (I doubt this padding mode is ever used).
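For context when comparing the two backends, numpy's `mean` mode fills each side with the mean of the original values along the axis (assuming the default `stat_length`, i.e. the mean over the whole axis). A quick check of the numpy reference behavior:

```python
import numpy as np

# numpy fills both pad regions with the axis mean of the original array.
x = np.array([1.0, 2.0, 6.0])   # mean is 3.0
padded = np.pad(x, (2, 1), mode="mean")
print(padded)                    # → [3. 3. 1. 2. 6. 3.]
```

Any JAX implementation that averages over a different window (or includes already-written pad values) would diverge from this.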
I also think my loopy pads (symmetric, wrap) need to be redone, because they are failing a new test that arbitrarily pads every dimension of an nd input differently. So all that probably needs a re-design from the ground up.
You may need an operation per dimension
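One way to sketch the per-dimension idea in plain numpy (an illustration of the approach, not the PR's code; `wrap_pad` is a hypothetical name): apply wrap padding to one axis at a time, so each step only has to reason about a single dimension.

```python
import numpy as np

def wrap_pad(x, pad_width):
    """Hypothetical sketch: periodic (wrap) padding, one axis at a time."""
    for axis, (before, after) in enumerate(pad_width):
        size = x.shape[axis]
        # Indices modulo the axis length implement periodic wrapping,
        # including pad widths larger than the axis itself.
        idx = np.arange(-before, size + after) % size
        x = np.take(x, idx, axis=axis)
    return x

a = np.arange(6).reshape(2, 3)
print(wrap_pad(a, [(1, 1), (2, 0)]))
print(np.pad(a, [(1, 1), (2, 0)], mode="wrap"))  # same result
```

Because each pass touches only one axis, padding every dimension of an nd input differently falls out for free.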
Regarding JAX do you need to implement a specific dispatch? For instance for the einsum I don't think we'll need because the OFG expression will be as good as what they do internally (since we copied it from them)
No idea on the JAX dispatch. I just assumed I should.
Closes #923
Some static shape information is getting lost between the scans, so the JAX test is failing now. Need to inspect the graphs and figure out what is going on :(
Not shocked. If JAX is fussy we may keep it in an OFG for it and inline in the remaining backends
Don't understand why the doctest for `pad` is failing.
It says there is an output that was not expected. If you have a print somewhere, you always need to test its output afterward.
You can also run doctest locally btw
Something like `pytest --doctest-modules pytensor/tensor/pad.py --verbose`
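To make the failure mode concrete, here is a toy example (hypothetical `double` function) of the rule above: anything a doctest example prints must appear as expected output, or doctest reports unexpected output.

```python
import doctest

def double(x):
    """Double the input.

    >>> print(double(21))
    42
    """
    # If the `42` line were missing from the docstring, doctest would
    # flag the printed value as unexpected output.
    return 2 * x

# testmod() re-runs the docstring examples in this module; 0 failures here.
print(doctest.testmod().failed)
```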
We should open a follow-up issue for performance. With the reshape and concatenation, we're doing a lot of copies. We should see how much better it would be to have scans with `set_subtensor`, like you tried halfway.
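The copy-avoiding idea, sketched in plain numpy (a 1-d illustration of the "preallocate and write slices" approach, analogous to what `set_subtensor` would do in a scan; `constant_pad_1d` is a made-up name):

```python
import numpy as np

def constant_pad_1d(x, before, after, value=0.0):
    """Hypothetical sketch: write into one preallocated buffer
    instead of concatenating, which copies at every step."""
    out = np.full(x.shape[0] + before + after, value, dtype=x.dtype)
    out[before:before + x.shape[0]] = x  # single copy of the input
    return out

x = np.array([1.0, 2.0, 3.0])
print(constant_pad_1d(x, 2, 1))
print(np.pad(x, (2, 1), mode="constant"))  # same result
```

Concatenation builds the result incrementally, copying all previously assembled data each time; the preallocated version copies each piece exactly once.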
I kind of just want to skip the segfault test and come back to it later. I am trying to debug, but I'm not really sure what's going on. It runs fine when the NUMBA_DISABLE_JIT flag is set. My suspicion is an out-of-range index, but I can't reproduce the error in a notebook. I'd be fine with removing the numba tests altogether and considering this mode unsupported, but every other mode passes, so idk. I just want to get this over the finish line.