
Improve Numba caching and source generation

Open brandonwillard opened this issue 2 years ago • 4 comments

Looks like this wasn't given its own issue, but we need to improve the caching and the way we generate source for Numba JITing.

To clarify, we need to do the following:

  • Cache the Python function objects returned by _numba_funcify based on their node arguments (i.e. the Apply nodes).
  • Store cached Python functions as persistent Python modules in aesara.config.compiledir

We can repurpose the compile directory code from the C backend to do most of this, but, if there are other packages that would simplify the entire process, we should look into those, too.
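For the second point, a minimal sketch of what persisting a generated function as a module could look like (the write_cached_module helper and the file layout are hypothetical, not an existing API):

import importlib.util
import sys
from pathlib import Path

import aesara


def write_cached_module(key: str, src: str):
    """Write generated Numba source to a module under the compile dir and import it."""
    # Hypothetical layout: one module per cache key under aesara.config.compiledir
    path = Path(aesara.config.compiledir) / f"numba_{key}.py"
    path.write_text(src)

    spec = importlib.util.spec_from_file_location(f"numba_{key}", path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module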

brandonwillard avatar Feb 10 '23 17:02 brandonwillard

I'm thinking of the following approach:

  1. A class in the numba backend that handles the cache -- say, NumbaCache.
  2. Some of the probable methods this class would have are:
    • refresh -- clear the cache
    • has_op_code -- check whether the cache already has an Op's Numba implementation
    • add_op_code -- add a Numba function to the cache
  3. We will have a global object of this class that will be used in each of the functions (see the sketch after this list).
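A rough sketch of that interface, with placeholder names and a plain dict as the backing store (none of this is settled):

class NumbaCache:
    """In-memory cache mapping a key derived from an Apply node to its Numba function."""

    def __init__(self):
        self._cache = {}

    def refresh(self):
        # Clear the cache
        self._cache.clear()

    def has_op_code(self, key):
        # Check whether the cache already has a Numba implementation for this key
        return key in self._cache

    def add_op_code(self, key, numba_fn):
        # Add the Numba function to the cache
        self._cache[key] = numba_fn

    def get_op_code(self, key):
        return self._cache[key]


# Module-level instance shared by the funcify functions
numba_cache = NumbaCache()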

Some of the questions in the implementation that we have are:

  1. Whether to use a .py file in the cache dir and append the function source code to it, or to use a pickling library like dill? -- I'm not entirely sure whether numba's njit will work with dill serialization and deserialization.
  2. How will we hash the Op node for a quick look-up?
  3. Will we check the cache dir while visiting all the nodes in the numba backend? -- I think that might affect compilation time in some cases.

Smit-create avatar Feb 23 '23 14:02 Smit-create

Here's a high-level pseudocode outline of what we need: https://gist.github.com/brandonwillard/b9262d0eccb0e7016f836447ba8870fc.

Some of the questions in the implementation that we have are:

  1. Whether to use a .py file in the cache dir and append the function source code to it, or to use a pickling library like dill? -- I'm not entirely sure whether numba's njit will work with dill serialization and deserialization.

Yes, we need to investigate whether or not we can directly use the pickled modules produced by dill or something similar. This is mocked up in my outline.
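A quick way to probe that question would be a round-trip check along these lines (just an experiment sketch, not a design):

import dill
import numba


@numba.njit
def add_one(x):
    return x + 1


# Compile once, then serialize and deserialize the dispatcher
add_one(1)
restored = dill.loads(dill.dumps(add_one))
print(restored(41))  # Expect 42 if the round trip works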

  2. How will we hash the Op node for a quick look-up?

The outline uses a not-so-good pickle-to-SHA256 scheme, but that's the basic idea behind one approach.
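In other words, something like the following (with the caveat that pickling an Apply node is neither canonical nor cheap):

import hashlib
import pickle


def node_cache_key(node) -> str:
    """Hash an Apply node's pickled form into a hex string usable as a cache key."""
    return hashlib.sha256(pickle.dumps(node)).hexdigest()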

  3. Will we check the cache dir while visiting all the nodes in the numba backend? -- I think that might affect compilation time in some cases.

Per the outline, we can use an in-memory, file-backed cache like shelve. Exactly how we use it is the question, though.
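For reference, shelve gives a dict-like interface backed by a file on disk, so lookups stay fast within a session while the entries persist across sessions (the file name and key here are just placeholders):

import shelve

# Open (or create) a persistent, dict-like cache keyed by node hashes
with shelve.open("numba_fn_cache") as cache:
    key = "some-node-hash"
    if key not in cache:
        cache[key] = "generated source (or a pickled function)"
    print(cache[key])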

brandonwillard avatar Feb 24 '23 22:02 brandonwillard

Once we have completed #1470, I think we can then try to refactor some nodes like Shape_i and use partial functions.

Numba with partial funcs

import numba
from functools import partial
import numpy as np
import time

# Sum the array and add a constant that is bound via functools.partial
def some_func(arr, other_one):
    t = 0.0
    for a in arr:
        t += a
    return t + other_one


arr = np.random.rand(1_000_000)

# First run: JIT-compile the function and time the call
a = time.time()
f = numba.njit(some_func, cache=True)
f2 = partial(f, other_one=100)
tmp = f2(arr)
b = time.time()
print("Time 1:", b - a, tmp)

# Second run: build a new dispatcher for the same function (with cache=True)
# and bind a different constant
a = time.time()
f = numba.njit(some_func, cache=True)
f2 = partial(f, other_one=0)
tmp = f2(arr)
b = time.time()
print("Time 2:", b - a, tmp)

Smit-create avatar Mar 14 '23 07:03 Smit-create

Once we have completed #1470, I think we can then try to refactor some nodes like Shape_i and use partial functions.

Are we sure that using partial is any better for compilation, though? I'm still not clear on that. It's definitely easier for us in the long run, so I'm all for it, but we need to know when/if it will help.

Also, we can make the tests work using the same approach we have been using. The Numba disable-JIT approach would be nicer, but it shouldn't be blocking anything.

brandonwillard avatar Mar 14 '23 15:03 brandonwillard