Reliable caching of Graphs and individual Ops in numba backend
Cherry-picking changes from #1604.
It seems that simply improving the caching of individual Ops gives us a lot of speedup while still jitting the whole graph.
This is still very dirty, and perhaps overkill: we compile+exec every single Op, even those that don't need our custom cache control. OTOH it's quite hard to know what numba will accept caching for, and as mentioned here, numba cache invalidation also leaves a lot to be desired: https://github.com/pymc-devs/pytensor/pull/1604#discussion_r2324786918
The compile+exec is needed to lift variables/functions out of the function closure into the global scope.
Otherwise numba will look into the closure to check if the cache is stale (numba always has the last word on whether a cache is stale or not). Depending on how they are serialized, these variables can look different even if they haven't changed and the function is exactly the same as before.
This is purely gaming numba for our purposes.
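To make the trick concrete, here is a minimal sketch of the compile+exec lifting (names like `build_op_fn` and `helper` are made up for illustration, and the cache handling itself is omitted):

```python
import numba

SRC = """
def add_op(x, y):
    # `helper` is resolved from the injected globals rather than from a
    # closure cell, so numba's staleness check has no closure to inspect
    return helper(x) + helper(y)
"""

def build_op_fn(src, name, global_env):
    # compile+exec the generated source with the helper functions placed
    # in the function's global namespace instead of its closure
    code = compile(src, f"<{name}>", "exec")
    namespace = dict(global_env)
    exec(code, namespace)
    return numba.njit(namespace[name])

@numba.njit
def helper(x):
    return 2 * x

add_op = build_op_fn(SRC, "add_op", {"helper": helper})
assert add_op(1.0, 2.0) == 6.0
```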
Benchmarking
Here are the timings for compiling the radon model repeatedly:
Before
```
--------------------------------------------------------------------------------------------------------- benchmark: 3 tests --------------------------------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_repeated_compile_benchmark[C_VM] 641.6810 (1.0) 923.9313 (1.03) 728.9724 (1.01) 118.2841 (1.16) 668.2491 (1.0) 149.5554 (1.27) 1;0 1.3718 (0.99) 5 1
test_radon_model_repeated_compile_benchmark[C] 653.8087 (1.02) 895.6774 (1.0) 723.1066 (1.0) 101.6445 (1.0) 675.7941 (1.01) 117.9712 (1.0) 1;0 1.3829 (1.0) 5 1
test_radon_model_repeated_compile_benchmark[NUMBA] 7,505.5632 (11.70) 8,140.9498 (9.09) 7,836.1051 (10.84) 241.2376 (2.37) 7,894.4678 (11.81) 332.0119 (2.81) 2;0 0.1276 (0.09) 5 1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
After
```
------------------------------------------------------------------------------------------------------ benchmark: 3 tests -----------------------------------------------------------------------------------------------------
Name (time in ms) Min Max Mean StdDev Median IQR Outliers OPS Rounds Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_repeated_compile_benchmark[C] 676.8789 (1.0) 1,028.5802 (1.05) 786.7192 (1.0) 140.3314 (1.41) 748.9989 (1.0) 139.2424 (1.25) 1;0 1.2711 (1.0) 5 1
test_radon_model_repeated_compile_benchmark[C_VM] 740.0859 (1.09) 980.4151 (1.0) 811.9622 (1.03) 99.4709 (1.0) 759.7109 (1.01) 111.1999 (1.0) 1;0 1.2316 (0.97) 5 1
test_radon_model_repeated_compile_benchmark[NUMBA] 762.9275 (1.13) 1,027.3377 (1.05) 888.7869 (1.13) 102.8226 (1.03) 900.3118 (1.20) 153.7402 (1.38) 2;0 1.1251 (0.89) 5 1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
And compiling slight variations of the same model (after clearing the cache for both backends):
Before
```
------------------------------------------------------------------------
Name (time in s) Runtime
------------------------------------------------------------------------
test_radon_model_compile_variants_benchmark[C_VM] 14.6317 (1.0)
test_radon_model_compile_variants_benchmark[C] 37.0302 (2.53)
test_radon_model_compile_variants_benchmark[NUMBA] 51.8231 (3.54)
------------------------------------------------------------------------
```
After
```
-----------------------------------------------------------------------
Name (time in s) Runtime
-----------------------------------------------------------------------
test_radon_model_compile_variants_benchmark[C_VM] 16.4719 (1.0)
test_radon_model_compile_variants_benchmark[NUMBA] 26.0248 (1.58)
test_radon_model_compile_variants_benchmark[C] 37.1087 (2.25)
-----------------------------------------------------------------------
```
And this comes at no cost in evaluation runtime, unlike the VM approach, which is 2x slower in the C backend and many times slower in the naive implementation in #1604 (the Numba overhead per individually jitted function added up there):
```
-------------------------------------------------------------------------------------------------- benchmark: 4 tests --------------------------------------------------------------------------------------------------
Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_call_benchmark[NUMBA] 21.3900 (1.0) 35.4870 (1.0) 25.0974 (1.0) 5.9128 (1.88) 22.2920 (1.0) 5.2898 (13.20) 1;1 39.8448 (1.0) 5 1
test_radon_model_call_benchmark[C] 22.6220 (1.06) 72.1550 (2.03) 27.6587 (1.10) 3.1417 (1.0) 27.8320 (1.25) 0.4007 (1.0) 2577;3297 36.1550 (0.91) 9691 1
test_radon_model_call_benchmark[C_VM_NOGC] 33.3120 (1.56) 1,177.2980 (33.18) 46.5852 (1.86) 16.1183 (5.13) 40.5760 (1.82) 12.9245 (32.25) 790;286 21.4661 (0.54) 8988 1
test_radon_model_call_benchmark[C_VM] 44.9940 (2.10) 633.7580 (17.86) 52.8585 (2.11) 11.7635 (3.74) 52.3990 (2.35) 6.5520 (16.35) 570;652 18.9184 (0.47) 8000 1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
Conclusion so far
- If you're recompiling the very same function, numba is now as fast as the other backends (and 10x faster than before)
- If you're compiling slightly different functions, numba is 2x slower than the default C_VM backend, but faster than the "equivalent" fully compiled C. Compared to itself, it is 2x faster than before, so in general numba will compile anywhere between 2x and 10x faster than before, for similar kinds of models.
TODO:
- [x] Hash constants correctly
- [x] Investigate whether doing `sha256(sha256(key1), sha256(key2))` has much higher collision chances than `sha256(key1, key2)` (see the sketch after this list)
- [x] Split numba benchmark tests and add a numba no-cache mode to the test
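For reference, the two constructions compared above look like this (toy keys, purely to illustrate):

```python
from hashlib import sha256

key1, key2 = b"op-source-code", b"node-signature"

# Flat: hash the concatenated sub-keys directly.
flat = sha256(key1 + key2).hexdigest()

# Nested: hash the concatenation of the fixed-length sub-digests. The inner
# digests act as unambiguous delimiters, which the flat form lacks
# (b"ab" + b"c" collides with b"a" + b"bc").
nested = sha256(sha256(key1).digest() + sha256(key2).digest()).hexdigest()
```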
Follow Up PRs
- [ ] Cache functions with pointers to scipy blas/cython functions: https://github.com/numba/numba/issues/10098#issuecomment-2959818159
- [x] Cache Scan
- [ ] Keep trying the numba VM idea, with the Rust glue that @aseyboldt is developing
- [ ] Reach out to numba devs to see if we can find a less "gaming" approach, as this will be subject to numba's whims
the suspense is killing me
Update: it's 2.5x faster than before.
Is this Good Enough(TM)?
> Is this Good Enough(TM)?
I've managed to cache the whole function graph (still in hack mode), which means that if you recompile the exact same graph it's as fast to compile as the other backends (and executes faster, so that's nice). This is not so rare when working in Jupyter notebooks and the like.
For compiling slightly different versions (new test) it's 2x faster than before and no longer the slowest backend. The closest is the C backend (as opposed to C_VM), which also compiles the whole graph.
It's still 2x slower than the C_VM. This is as expected; it's why the Theano folks moved away from C as the default. However, that may actually be fine for now? We can tell users to switch to the C_VM backend if the compile times are prohibitive?
Neat! Yeah, sounds like this is good enough to switch the default backend to numba.
Okay, this is getting there. Missing some docstrings and minor cleanup. The big pain point is the blas/lapack functions, which we'll have to handle manually as well, but I would leave those to a separate PR.
I can re-run the whole numba CI in 5 minutes now, of which the uncacheable blas/lapack stuff takes 3-4 minutes.
This will also unblock https://github.com/pymc-devs/pytensor/pull/1445, which was blocked by our clumsy/eager attempt to cache everything, even the uncacheable. The new system allows us to cleanly inform higher-order functions when a sub-function is uncacheable, and therefore that the functions using it are as well.
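As a rough sketch of that idea (names here are hypothetical, not the PR's actual API): each compiled sub-function carries an optional cache key, and a higher-order function only gets a key if every sub-function has one:

```python
from dataclasses import dataclass
from hashlib import sha256
from typing import Callable

@dataclass
class CompiledFn:
    fn: Callable
    cache_key: str | None  # None marks the function as uncacheable

def composite_key(outer_key: str, subfns: list[CompiledFn]) -> str | None:
    # An uncacheable sub-function makes the whole composite uncacheable,
    # so we never write a cache entry that could silently go stale.
    if any(sub.cache_key is None for sub in subfns):
        return None
    joined = "|".join([outer_key, *(sub.cache_key for sub in subfns)])
    return sha256(joined.encode()).hexdigest()
```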
Codecov Report
:x: Patch coverage is 91.06830% with 51 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 81.75%. Comparing base (60ba7c7) to head (9158fce).
:warning: Report is 4 commits behind head on main.
:x: Your patch check has failed because the patch coverage (91.06%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
```diff
@@ Coverage Diff @@
## main #1637 +/- ##
==========================================
+ Coverage 81.67% 81.75% +0.08%
==========================================
Files 244 248 +4
Lines 53558 53822 +264
Branches 9433 9459 +26
==========================================
+ Hits 43741 44002 +261
+ Misses 7337 7333 -4
- Partials 2480 2487 +7
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/compile/mode.py | 85.13% <100.00%> (+0.13%) | :arrow_up: |
| pytensor/configparser.py | 92.60% <100.00%> (+0.04%) | :arrow_up: |
| pytensor/link/numba/dispatch/__init__.py | 100.00% <100.00%> (ø) | |
| pytensor/link/numba/dispatch/extra_ops.py | 96.92% <100.00%> (+0.82%) | :arrow_up: |
| ...sor/link/numba/dispatch/linalg/decomposition/lu.py | 66.66% <100.00%> (ø) | |
| ...or/link/numba/dispatch/linalg/solve/tridiagonal.py | 55.39% <100.00%> (ø) | |
| pytensor/link/numba/dispatch/nlinalg.py | 100.00% <100.00%> (ø) | |
| pytensor/link/numba/dispatch/shape.py | 100.00% <100.00%> (ø) | |
| pytensor/link/numba/dispatch/signal/conv.py | 32.69% <100.00%> (ø) | |
| pytensor/link/numba/dispatch/slinalg.py | 68.54% <100.00%> (+0.14%) | :arrow_up: |
| ... and 16 more | | |
Just wrestling with mypy at this point; everything is ready otherwise.
One more manual benchmark of the radon model:

Before:
- No caching: 7.5s
- Caching except outer FunctionGraph: 6s
- Caching: 6s

After:
- No caching: 8.5s
- Caching except outer FunctionGraph: 1.8s
- Caching: 0.6s
Starting fresh (without caching) seems a bit slower. However, we later benefit more from individual Op compilation being shared across different calls.
I'm not sure what explains the difference without caching. I got the timings by setting `pytensor.config.numba__cache = False`. I guess `register_jitable` instead of an immediate `njit` is slower? Or Numba does within-runtime caching, and it did so better before than now. I'm not creating intermediate wrapper functions when cache is disabled, so it's not that there is more fluff.
Otherwise, all the key computation could explain the slowdown?
Note that this means our numba CI won't see much speedup (if any at all). It mostly tests single Ops one at a time, so it doesn't benefit much from the sharing of cached Ops (other than the occasional dimshuffle/indexing that recurs). If you run it twice with caching enabled you'll see a huge difference, so it's still a win for dev time.
Edit: I think I brought the non-cache difference slightly down, although measurements are still slow and noisy.
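For reference, this is roughly how such a timing can be taken with the cache disabled, using the flag mentioned above (with a toy graph here instead of the radon model):

```python
import time

import pytensor
import pytensor.tensor as pt

# Disable the new numba cache to reproduce the "No caching" setting above
pytensor.config.numba__cache = False

x = pt.vector("x")
start = time.perf_counter()
fn = pytensor.function([x], (2 * x).sum(), mode="NUMBA")
fn([1.0, 2.0])  # numba compiles lazily, on the first call
print(f"compile time: {time.perf_counter() - start:.2f}s")
```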
Do we want to keep the `eval_obj_mode` thing in `compare_numba_and_py`? It requires us to always use `numba_basic.numba_njit` instead of importing `numba_njit` directly, because the latter can't be properly patched. Also, if you forget to use the monkey-patchable method in a file you may get a cryptic segfault/numba error, because it tries to use an njit function that is not meant to be called directly from Python (that's the default when you call `numba_njit`, which has `final_function=False`).
Do you know of any places where we need it?
> Do you know of any places where we need it?
We don't need it. This is just for code coverage, and supposedly to find trivial Python errors in Python mode instead of through the unreadable numba tracebacks when developing.
Tests are passing and all immediate TODOs have been addressed.
> but can't we usually just ask the Op itself to give us a cache key, and move some of the new code there?
I think this is backend specific. The same Op may require distinct functions in one backend (like numba does with `Subtensor`) but a single one in another (jax is just `x[indices]`). But I may be missing what you had in mind specifically?