Add missing numba ops

Open twiecki opened this issue 2 years ago • 14 comments

Description

We're fairly close to having full Numba support, but a few important Ops still lack Numba implementations. The list below is incomplete; we should complete it and then make a push to add the missing Ops.

Here is a tutorial on how to add them: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html

  • [ ] Advanced(Inc)Subtensor
  • [x] SolveTriangular
  • [x] Logdet https://github.com/pymc-devs/pytensor/pull/172
  • [ ] (Probably) other linalg stuff
  • [ ] Broadcasting in Advanced(Inc)Subtensor1
  • [x] LogSumExp (https://github.com/aesara-devs/aesara/issues/404 for code example)
  • [ ] aesara/tensor/basic.py: Flatten(COp)
  • [x] erfcx

twiecki avatar Jan 03 '23 05:01 twiecki

In aesara this was being tracked in this milestone. The particular COps were being tracked in this issue. There's also the very important fact that we need to use Numba's NumPy-compatible random Generator objects.

lucianopaz avatar Jan 03 '23 10:01 lucianopaz

I think replacing C is too lofty of a goal for now. Instead, we should focus on making all of PyMC work with numba.

twiecki avatar Jan 03 '23 10:01 twiecki

erfcx is also missing

fonnesbeck avatar Jan 03 '23 12:01 fonnesbeck

> erfcx is also missing

Doesn't it just default to the Python implementation?

ricardoV94 avatar Jan 03 '23 15:01 ricardoV94

> LogSumExp (aesara-devs/aesara#404 for code example)

We didn't have a LogSumExp Op last time I checked. It's built from other Ops, so we shouldn't need to implement it.

ricardoV94 avatar Jan 03 '23 15:01 ricardoV94

I want to shill the work I did on numba links for several linear algebra operations again:

https://github.com/numba/numba-scipy/commit/462191b4f745ed260056c534e28e8e0ba1a743a5

A core problem is that the numba.np.linalg module simply doesn't have links to several LAPACK functions we care about, including (but not limited to) SolveTriangular. There's also zero support for scipy.linalg functions. I show in that code it's not a big deal to write the hooks, but it's also not clear (to me anyway) where they should go in the code base. Numba-scipy is the most natural place, but I don't know if it's being actively developed beyond the scipy.experimental module.

jessegrabowski avatar Jan 03 '23 17:01 jessegrabowski

@jessegrabowski Looks interesting! You're just linking to a commit. Is that in main, or is it part of an outstanding PR? How do you propose we integrate this?

twiecki avatar Jan 03 '23 17:01 twiecki

It's an outstanding PR that didn't garner any attention.

I have no idea what the best way to integrate it is. I use this code in one of my own projects, where I just shoved it into a sub-module and added an entry point for the numba overloads to my setup.py. That worked fine and lets me use the overloaded functions under @njit decorators. I don't think it's a very "principled" approach, though.

jessegrabowski avatar Jan 03 '23 17:01 jessegrabowski

@jessegrabowski I think we could just put that code in linker/numba/dispatch/linalg.py. We don't need to override the scipy functions (that would change the behavior of completely unrelated code, just because someone imports pytensor), but just provide our own linalg functions for the ops.
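
A minimal sketch of that idea, assuming nothing about what linker/numba/dispatch/linalg.py actually contains: a standalone triangular solve that Numba can compile, leaving scipy.linalg itself untouched. It uses hand-written forward substitution rather than the LAPACK bindings from the numba-scipy commit, and the name `solve_lower_triangular` is hypothetical. The try/except only keeps the example runnable when Numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs as plain Python without Numba.
    def njit(func):
        return func


@njit
def solve_lower_triangular(L, b):
    """Solve L @ x = b for lower-triangular L by forward substitution.

    Hypothetical helper: stands in for a LAPACK-backed implementation.
    """
    n = L.shape[0]
    x = np.zeros(n)
    for i in range(n):
        acc = b[i]
        # Subtract the contributions of the already-solved unknowns.
        for j in range(i):
            acc -= L[i, j] * x[j]
        x[i] = acc / L[i, i]
    return x


L = np.array([[2.0, 0.0, 0.0],
              [1.0, 3.0, 0.0],
              [4.0, 5.0, 6.0]])
x_true = np.array([1.0, 2.0, 3.0])
b = L @ x_true

x = solve_lower_triangular(L, b)
```

Because the function lives under its own name, code that imports scipy.linalg elsewhere is unaffected; only the Op's Numba dispatch would call it.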

aseyboldt avatar Jan 03 '23 18:01 aseyboldt

erfcx works for me, that was added here: #46

aseyboldt avatar Jan 03 '23 19:01 aseyboldt

logsumexp is currently built from other parts (see the attached image). Maybe we could improve the performance of this further, but for now I think this is fine.
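
For reference, here is roughly what "built from other parts" means, sketched in plain NumPy. This is the standard max-shift formulation and is only an assumption about the shape of the graph; the exact Ops PyTensor composes are not reproduced here.

```python
import numpy as np


def logsumexp(x, axis=None):
    """Numerically stable log(sum(exp(x))) composed of elementary operations."""
    x = np.asarray(x, dtype=float)
    x_max = np.max(x, axis=axis, keepdims=True)
    # Shift by the maximum so the largest term is exp(0) = 1 and cannot overflow.
    shifted_sum = np.sum(np.exp(x - x_max), axis=axis, keepdims=True)
    return np.squeeze(np.log(shifted_sum) + x_max, axis=axis)


values = np.array([1000.0, 1000.0])

# The naive formula overflows: exp(1000) is not representable in float64.
with np.errstate(over="ignore"):
    naive = np.log(np.sum(np.exp(values)))

stable = logsumexp(values)
```

Since every step (max, subtract, exp, sum, log, add) is an elementary operation, a backend that supports those gets logsumexp without a dedicated implementation.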

aseyboldt avatar Jan 03 '23 20:01 aseyboldt

Hi @twiecki! Regarding Det/Logdet from the list, it looks like Det is already available. Is that right? https://github.com/pymc-devs/pytensor/blob/38dc6c9f60c45bf8d00b7201bc64139ea88c0132/pytensor/link/numba/dispatch/nlinalg.py#L48

By Logdet do you mean slogdet?
https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html#numpy.linalg.slogdet

[EDIT] Opened a WIP PR for the latter: https://github.com/pymc-devs/pytensor/pull/172
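
To illustrate why slogdet is the useful primitive here, a short NumPy sketch (the matrices are made up for the example): slogdet returns the sign and the log of the absolute determinant separately, so it stays finite where det followed by log overflows.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(4, 4))
A = a @ a.T + 4 * np.eye(4)  # symmetric positive definite, so det(A) > 0

# For a well-scaled matrix both routes agree.
sign, logabsdet = np.linalg.slogdet(A)
logdet_naive = np.log(np.linalg.det(A))

# For a badly scaled matrix the determinant itself overflows float64,
# while the log-determinant is still perfectly representable.
big = 1e200 * np.eye(4)
with np.errstate(over="ignore"):
    det_big = np.linalg.det(big)
sign_big, logabsdet_big = np.linalg.slogdet(big)
```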

mtsokol avatar Jan 04 '23 17:01 mtsokol

@mtsokol That looks right -- thanks for opening a PR!

twiecki avatar Jan 05 '23 02:01 twiecki

Some cases of AdvancedSubtensor can be supported by clever reshaping, and indexing based on strides.

A snippet we were using some time ago, pasted here completely out of context:

x, y = design_matrix.nonzero()
*s1, s2, s3 = result.shape
return result.reshape(*s1, s2*s3)[..., x*s3 + y]
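
A self-contained version of the same trick, assuming `design_matrix` has the shape of the last two axes of `result` (the shapes and values below are made up). It checks that flattening those two axes and indexing with a single integer vector gives the same answer as indexing with two advanced indices.

```python
import numpy as np

rng = np.random.default_rng(0)
result = rng.normal(size=(2, 5, 3, 4))
design_matrix = np.array([[1, 0, 0, 1],
                          [0, 1, 0, 0],
                          [0, 0, 1, 1]])

x, y = design_matrix.nonzero()
*s1, s2, s3 = result.shape

# Merge the last two axes; element (i, j) then lives at flat position i * s3 + j
# for a C-contiguous layout.
via_reshape = result.reshape(*s1, s2 * s3)[..., x * s3 + y]

# The direct form, which needs two advanced indices at once.
via_advanced = result[..., x, y]
```

The reshaped form only needs indexing along one axis with one integer vector, which is the simpler case to support.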

ricardoV94 avatar Jan 11 '24 07:01 ricardoV94