Restore fast convolution Ops, rewrites, and docs
Description
When the nnet sub-module was deprecated, the old Theano convolution functions went into the trashcan, along with the associated docs. This was partially reverted, but the docs and the "efficient" versions of the Ops are still missing.
There is interest in convolutions on the Discourse from time to time, so it's worth talking about these.
This PR is a draft because I just restored everything. To be frank, this part of the library seems somewhat bloated and would probably be better served by an overhaul than by simply restoring this code. For example, there is commentary about CUDA kernels in the docstrings -- I assume this is from the days when Theano was trying to target GPUs with pycuda. We use JAX these days, and I wager these Ops don't compile to JAX at all (though I haven't tried), or to numba for that matter...
It's marked as draft for 2 reasons:
- It might be better to simply re-implement convolution, be it 1d or 2d, in a manner that more closely matches the numpy and scipy signatures, using pure PyTensor Ops derived from scan (see the sketch after this list). For batches of convolutional filters, I'm fairly sure we can just use Blockwise.
- Even if we think this existing code is great for CPU optimization (which probably remains the most common use case for our users), I definitely restored too much. Do we really want all the 3d convolutions? There are also tons of references along the lines of "use this one, not that one" -- I haven't taken the time to figure out which one is the good one. If we want to keep that, we should rip all the other ones out.
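For illustration, here is a minimal sketch of the scan-based approach mentioned in the first point: a 1D "valid" convolution built only from existing PyTensor Ops. The helper name `conv1d_valid` is hypothetical, not an existing PyTensor function; it follows numpy.convolve's "valid" mode.

```python
import pytensor
import pytensor.tensor as pt


def conv1d_valid(x, kernel, kernel_size):
    # Flip the kernel (convolution rather than correlation) and take a dot
    # product against each length-`kernel_size` window of x, scanning over
    # the window start indices.
    flipped = kernel[::-1]
    n_windows = x.shape[0] - kernel_size + 1

    def step(i, x, flipped):
        return pt.dot(x[i:i + kernel_size], flipped)

    out, _ = pytensor.scan(
        step,
        sequences=pt.arange(n_windows),
        non_sequences=[x, flipped],
    )
    return out


x = pt.vector("x")
k = pt.vector("k")
y = conv1d_valid(x, k, kernel_size=3)
f = pytensor.function([x, k], y)
# f([1, 2, 3, 4, 5], [1, 0, -1]) should give [2., 2., 2.],
# matching np.convolve([1, 2, 3, 4, 5], [1, 0, -1], mode="valid")
```

A batched or 2d version could presumably follow the same pattern, with Blockwise (or another scan) mapping over filter and channel dimensions.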
Comments requested.
Related Issue
- [x] Closes #305
- [x] Closes #275
- [ ] Related issue (not closed by this PR): #
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [x] Included tests that prove the fix is effective or that the new feature works
- [x] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [x] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
+1 Following the numpy/scipy API would be best. Part of the reason we usually do that is that it saves us from having to think about the design, and saves users from having to learn something new.
So scipy.signal seems to have convolve/convolve2d and correlate (I saw we have a corr-something?), and they are already batched (at least convolve is)? JAX also followed their API: https://jax.readthedocs.io/en/latest/notebooks/convolutions.html
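For reference, a tiny sketch of the scipy signature in question and its JAX mirror (this just demonstrates the target API on plain arrays, assuming scipy and jax are installed; it is not PyTensor code):

```python
import numpy as np
from scipy import signal
import jax.numpy as jnp
import jax.scipy.signal as jsignal

a = np.random.randn(8, 8)
k = np.random.randn(3, 3)

# scipy.signal.convolve2d and its JAX counterpart share the same signature
out_scipy = signal.convolve2d(a, k, mode="same")
out_jax = jsignal.convolve2d(jnp.asarray(a), jnp.asarray(k), mode="same")

np.testing.assert_allclose(out_scipy, np.asarray(out_jax), rtol=1e-5)
```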
I would suggest trying to implement an Op like the abstract ones, perhaps much simpler, that corresponds to this API, and then specializing it into the fancy ones via rewrites. If you think one of them is much less relevant, feel free to kick it out.
I would focus first on mapping to JAX (and numba, if they have anything), and if some of these cases have a nice correspondence to the old C Ops, all the better; if not, that's one more reason to kick them out.
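A minimal sketch of the "abstract Op + specialize via rewrites" pattern described above, assuming current PyTensor rewriting imports; `Conv1d` and `specialize_conv1d` are illustrative names, and the rewrite body is a placeholder since the target "fancy" Ops are exactly what's being debated here:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_specialize


class Conv1d(Op):
    """Abstract 1D convolution following numpy.convolve's "full" mode."""

    def make_node(self, x, kernel):
        x = pt.as_tensor_variable(x)
        kernel = pt.as_tensor_variable(kernel)
        out = pt.vector(dtype=x.dtype)
        return Apply(self, [x, kernel], [out])

    def perform(self, node, inputs, output_storage):
        x, kernel = inputs
        output_storage[0][0] = np.convolve(x, kernel, mode="full")


conv1d = Conv1d()


@register_specialize
@node_rewriter([Conv1d])
def specialize_conv1d(fgraph, node):
    # Placeholder: replace the abstract node with whichever "efficient"
    # implementation the active backend supports; returning None leaves
    # the graph unchanged.
    return None
```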
Also, I see they had old logic for shape inference in the C code; we should update it to use the new static shape machinery.