Improve Numba caching and source generation
Looks like this wasn't given its own issue, but we need to improve the caching and the way we generate source for Numba JITing.
To clarify, we need to do the following:
- Cache the Python function objects returned by `_numba_funcify` based on their `node` arguments (i.e. the `Apply` nodes).
- Store cached Python functions as persistent Python modules in `aesara.config.compiledir`.
We can repurpose the compile-directory code from the C backend to do most of this, but, if there are other packages that would simplify the entire process, we should look into those, too.
I'm thinking of the following approach:
- A class in `numba` that handles the cache -- say, `NumbaCache`.
- Some of the probable methods this class would have are:
  - `refresh` -- clear the cache
  - `has_op_code` -- check whether the cache has an `Op` implementation in Numba
  - `add_op_code` -- add the Numba function to the cache
- We will have a global object of this class that will be used in each of the functions.
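As a rough illustration, here is a minimal in-memory sketch of such a class. The class and method names follow the list above; the dict storage and string keys are placeholders -- the real version would be backed by `aesara.config.compiledir` and keyed on node hashes:

```python
# Hypothetical sketch of the proposed NumbaCache class; a dict stands in
# for the persistent, compiledir-backed storage.
class NumbaCache:
    def __init__(self):
        self._funcs = {}

    def refresh(self):
        """Clear the cache."""
        self._funcs.clear()

    def has_op_code(self, key):
        """Check whether the cache has an Op implementation in Numba."""
        return key in self._funcs

    def add_op_code(self, key, fn):
        """Add the Numba function to the cache."""
        self._funcs[key] = fn


# The global object that each conversion function would consult.
numba_cache = NumbaCache()
numba_cache.add_op_code("Add", lambda x, y: x + y)
print(numba_cache.has_op_code("Add"))  # True
numba_cache.refresh()
print(numba_cache.has_op_code("Add"))  # False
```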
Some of the questions in the implementation that we have are:
- Whether to use a `.py` file in the cache dir and append the function source code to it, or to use pickling libraries like `dill`? -- I'm not entirely sure if `numba`'s `njit` will work with serialization and deserialization in `dill`.
- How will we hash the `Op` node for a quick look-up?
- Will we check the cache dir while visiting all the nodes in the `numba` backend? -- I think that might affect compilation time in some cases.
Here's a high-level pseudocode outline of what we need: https://gist.github.com/brandonwillard/b9262d0eccb0e7016f836447ba8870fc.
> Some of the questions in the implementation that we have are:
>
> - Whether to use a `.py` file in the cache dir and append the function source code to it, or to use pickling libraries like `dill`? -- I'm not entirely sure if `numba`'s `njit` will work with serialization and deserialization in `dill`.
Yes, we need to investigate whether or not we can directly use the pickled modules produced by dill or something similar. This is mocked-up in my outline.
> - How will we hash the `Op` node for a quick look-up?
The outline uses a not-so-good pickle-to-SHA256 approach, but that's the basic idea for one approach.
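For reference, the pickle-to-SHA256 idea amounts to something like the following. Here a plain tuple stands in for the actual `Apply` node, which is an assumption for illustration; one reason the approach is "not-so-good" is that pickle bytes are not guaranteed to be stable across Python or protocol versions:

```python
import hashlib
import pickle


def node_key(node):
    """Hash a (picklable) node representation to a fixed-size cache key."""
    return hashlib.sha256(pickle.dumps(node)).hexdigest()


# Stand-in for an Apply node: (Op name, input dtypes).
key = node_key(("Shape_i", ("float64",)))
print(len(key))  # 64 hex characters
```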
> - Will we check the cache dir while visiting all the nodes in the `numba` backend? -- I think that might affect compilation time in some cases.
Per the outline, we can use an in-memory, file-backed cache like `shelve`. Exactly how we use it is the question, though.
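A minimal sketch of the `shelve` idea, with a made-up key and source string standing in for the real node hash and generated code:

```python
import os
import shelve
import tempfile

cache_path = os.path.join(tempfile.mkdtemp(), "numba_op_cache")

# First "compilation": store generated source under the node's key.
with shelve.open(cache_path) as cache:
    cache["fake-node-hash"] = "def fn(x):\n    return x + 1\n"

# Later look-up: load the source back and reconstruct the function.
with shelve.open(cache_path) as cache:
    source = cache.get("fake-node-hash")

namespace = {}
exec(source, namespace)
print(namespace["fn"](41))  # 42
```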
Once we have completed #1470, I think we can then try to refactor some nodes like `Shape_i` and use `partial` functions.
Numba with partial funcs:

```python
import time
from functools import partial

import numba
import numpy as np


def some_func(arr, other_one):
    t = 0.0
    for a in arr:
        t += a
    return t + other_one


arr = np.random.rand(1_000_000)

a = time.time()
f = numba.njit(some_func, cache=True)
f2 = partial(f, other_one=100)
tmp = f2(arr)
b = time.time()
print("Time 1:", b - a, tmp)

a = time.time()
f = numba.njit(some_func, cache=True)
f2 = partial(f, other_one=0)
tmp = f2(arr)
b = time.time()
print("Time 2:", b - a, tmp)
```
> Once we have completed #1470, I think we can then try to refactor some nodes like `Shape_i` and use `partial` functions.
Are we sure that using `partial` is any better for compilation, though? I'm still not clear on that. It's definitely easier for us in the long run, so I'm all for it, but we need to know when/if it will help.
Also, we can make the tests work using the same approach we have been using. The Numba disable-JIT approach would be nicer, but it shouldn't be blocking anything.