Add missing numba ops
Description
We're fairly close to having full Numba support, but a few important Ops are still missing Numba implementations. This is an incomplete list that we should flesh out and then make a push to add them.
Here is a tutorial on how to add them: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html
- [ ] Advanced(Inc)Subtensor
- [x] SolveTriangular
- [x] Logdet https://github.com/pymc-devs/pytensor/pull/172
- [ ] (Probably) other linalg stuff
- [ ] Broadcasting in Advanced(Inc)Subtensor1
- [x] LogSumExp (https://github.com/aesara-devs/aesara/issues/404 for code example)
- [ ] aesara/tensor/basic.py: Flatten(COp)
- [x] erfcx
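For reference, the broadcasting behavior needed for `AdvancedIncSubtensor1` corresponds, as far as I understand it, to NumPy's unbuffered `np.add.at` semantics: repeated indices accumulate, and the incremented values broadcast against the selected rows. A minimal illustration (the arrays are made up):

```python
import numpy as np

x = np.zeros((3, 2))
# Unbuffered increment: index 0 appears twice, so row 0 is incremented
# twice; the scalar 1.0 broadcasts across each selected row.
np.add.at(x, np.array([0, 0, 2]), 1.0)
```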
In aesara this was being tracked in this milestone. The particular COps were being tracked in this issue. There's also the very important point that we need to use numba's NumPy-compatible random Generator objects.
I think replacing C is too lofty of a goal for now. Instead, we should focus on making all of PyMC work with numba.
erfcx is also missing
Doesn't it just default to the Python implementation?
LogSumExp (aesara-devs/aesara#404 for code example)
We don't have a LogSumExp last time I checked. It's built from other Ops, so we shouldn't need to implement it
I want to shill the work I did on numba links for several linear algebra operations again:
https://github.com/numba/numba-scipy/commit/462191b4f745ed260056c534e28e8e0ba1a743a5
A core problem is that the numba.np.linalg module simply doesn't have links to several LAPACK functions we care about, including (but not limited to) SolveTriangular. There's also zero support for scipy.linalg functions. I show in that code it's not a big deal to write the hooks, but it's also not clear (to me anyway) where they should go in the code base. Numba-scipy is the most natural place, but I don't know if it's being actively developed beyond the scipy.experimental module.
@jessegrabowski Looks interesting! You're just linking to a commit; is that in main, or is it part of an outstanding PR? How do you propose we integrate it?
It's an outstanding PR that didn't garner any attention.
I have no idea the best way to integrate it. I use this code in one of my own projects, where I just shoved it into a sub-module and added an entry point for the numba overloads into my setup.py. That worked fine and lets me use the overloaded functions under @njit decorators. I don't think it's a very "principled" approach, though.
@jessegrabowski I think we could just put that code in linker/numba/dispatch/linalg.py. We don't need to override the scipy functions (that would change the behavior of completely unrelated code, just because someone imports pytensor), but just provide our own linalg functions for the ops.
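To make the idea concrete, here's a sketch of the kind of kernel such a `linalg.py` module could provide. This is a hypothetical helper, not the linked PR's code: a loop-based forward substitution in plain NumPy (no fancy indexing), so the same function could in principle be compiled with `numba.njit` instead of calling into LAPACK:

```python
import numpy as np

def solve_lower_triangular(a, b):
    """Forward substitution for a lower-triangular system a @ x = b.

    Hypothetical stand-in for a SolveTriangular kernel; written with
    explicit loops so it stays numba-friendly.
    """
    n = a.shape[0]
    x = np.zeros_like(b, dtype=np.float64)
    for i in range(n):
        acc = b[i]
        for j in range(i):
            acc = acc - a[i, j] * x[j]
        x[i] = acc / a[i, i]
    return x
```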
erfcx works for me, that was added here: #46
logsumexp is currently built from other parts:
Maybe we could improve performance of this further, but for now I think this is fine.
Hi @twiecki! When it comes to Det/Logdet from the list, it looks like Det is already available, is that right?
https://github.com/pymc-devs/pytensor/blob/38dc6c9f60c45bf8d00b7201bc64139ea88c0132/pytensor/link/numba/dispatch/nlinalg.py#L48
By Logdet do you mean slogdet?
https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html#numpy.linalg.slogdet
[EDIT] Opened a WIP PR for the latter: https://github.com/pymc-devs/pytensor/pull/172
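If Logdet is indeed meant as the log-determinant, it can be expressed through `np.linalg.slogdet` (which numba's `np.linalg` support already covers, as far as I can tell). A minimal sketch of the wrapper:

```python
import numpy as np

def logdet(a):
    # slogdet returns (sign, log|det|) and avoids the overflow/underflow
    # that log(np.linalg.det(a)) would hit for large matrices.
    sign, ld = np.linalg.slogdet(a)
    if sign <= 0:
        return np.nan  # log-determinant undefined for det <= 0
    return ld
```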
@mtsokol That looks right -- thanks for opening a PR!
Some cases of AdvancedSubtensor can be supported by clever reshaping, and indexing based on strides.
Here's a snippet we were using some time ago, pasted completely out of context:
```python
# Flatten the trailing two axes and index them jointly with a single
# flat index built from the row/column positions of the nonzeros.
x, y = design_matrix.nonzero()
*s1, s2, s3 = result.shape
return result.reshape(*s1, s2 * s3)[..., x * s3 + y]
```
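To put the snippet in context: a small self-contained demo (the array contents are made up; the names mirror the snippet) showing that the flattened-index expression matches plain advanced indexing on the last two axes:

```python
import numpy as np

# Batched array whose last two axes we want to index jointly.
result = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
design_matrix = np.array([[1, 0, 0, 1],
                          [0, 0, 1, 0],
                          [0, 1, 0, 0]])

x, y = design_matrix.nonzero()      # row/col indices of the nonzeros
*s1, s2, s3 = result.shape          # leading batch dims, then (s2, s3)
flat = result.reshape(*s1, s2 * s3)[..., x * s3 + y]
# Equivalent to advanced indexing the last two axes directly:
direct = result[..., x, y]
```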