
Reliable caching of Graphs and individual Ops in numba backend

Open ricardoV94 opened this issue 2 months ago • 13 comments

Cherry picking changes from #1604

It seems that simply improving the caching of individual Ops gives us a lot of speedup while still jitting the whole graph.

This is still very dirty, and perhaps overkill: we compile+exec every single Op even when it doesn't need our custom cache control. OTOH it's quite hard to know what numba will or won't accept caching for, and as mentioned here, numba cache invalidation also leaves a lot to be desired: https://github.com/pymc-devs/pytensor/pull/1604#discussion_r2324786918

The compile+exec is needed to lift variables/functions out of the function closure into the global scope.

Otherwise numba will look into those to check if the cache is stale (numba always has the last word on whether a cache is stale or not). Depending on how they are serialized, these variables can look different even if they haven't changed and the function is exactly the same as before.

This is purely gaming numba for our purposes.
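
To make the trick concrete, here is a minimal sketch with a hypothetical helper name and a toy function (the real dispatch code generates the source and globals from the Op): the function is compiled+exec'd with its constants placed directly in its globals, so numba sees ordinary module-level names instead of closure cells.

```python
# Minimal sketch of the compile+exec lifting (illustrative only, not the
# actual PyTensor helper): closed-over objects are placed in the exec'd
# namespace, so the resulting function has no closure for numba to inspect
# when deciding whether its on-disk cache is stale.
def lift_closure(src, name, global_env):
    namespace = dict(global_env)  # constants/helpers become plain globals
    exec(compile(src, f"<pytensor {name}>", "exec"), namespace)
    return namespace[name]


add_one_src = """
def add_one(x):
    return x + offset
"""
# `offset` ends up as a global of the generated function, not a free variable
add_one = lift_closure(add_one_src, "add_one", {"offset": 1})
assert add_one(41) == 42
assert add_one.__closure__ is None
```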

Benchmarking

And here are the timings for compiling the radon model repeatedly:

Before
--------------------------------------------------------------------------------------------------------- benchmark: 3 tests --------------------------------------------------------------------------------------------------------
Name (time in ms)                                             Min                   Max                  Mean              StdDev                Median                 IQR            Outliers     OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_repeated_compile_benchmark[C_VM]        641.6810 (1.0)        923.9313 (1.03)       728.9724 (1.01)     118.2841 (1.16)       668.2491 (1.0)      149.5554 (1.27)          1;0  1.3718 (0.99)          5           1
test_radon_model_repeated_compile_benchmark[C]           653.8087 (1.02)       895.6774 (1.0)        723.1066 (1.0)      101.6445 (1.0)        675.7941 (1.01)     117.9712 (1.0)           1;0  1.3829 (1.0)           5           1
test_radon_model_repeated_compile_benchmark[NUMBA]     7,505.5632 (11.70)    8,140.9498 (9.09)     7,836.1051 (10.84)    241.2376 (2.37)     7,894.4678 (11.81)    332.0119 (2.81)          2;0  0.1276 (0.09)          5           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After
------------------------------------------------------------------------------------------------------ benchmark: 3 tests -----------------------------------------------------------------------------------------------------
Name (time in ms)                                           Min                   Max                Mean              StdDev              Median                 IQR            Outliers     OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_repeated_compile_benchmark[C]         676.8789 (1.0)      1,028.5802 (1.05)     786.7192 (1.0)      140.3314 (1.41)     748.9989 (1.0)      139.2424 (1.25)          1;0  1.2711 (1.0)           5           1
test_radon_model_repeated_compile_benchmark[C_VM]      740.0859 (1.09)       980.4151 (1.0)      811.9622 (1.03)      99.4709 (1.0)      759.7109 (1.01)     111.1999 (1.0)           1;0  1.2316 (0.97)          5           1
test_radon_model_repeated_compile_benchmark[NUMBA]     762.9275 (1.13)     1,027.3377 (1.05)     888.7869 (1.13)     102.8226 (1.03)     900.3118 (1.20)     153.7402 (1.38)          2;0  1.1251 (0.89)          5           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

And here are the timings for compiling slight variations of the same model (after clearing the cache for both backends):

Before
------------------------------------------------------------------------
Name (time in s)                                           Runtime          
------------------------------------------------------------------------
test_radon_model_compile_variants_benchmark[C_VM]      14.6317 (1.0)    
test_radon_model_compile_variants_benchmark[C]         37.0302 (2.53)   
test_radon_model_compile_variants_benchmark[NUMBA]     51.8231 (3.54)   
------------------------------------------------------------------------

After
-----------------------------------------------------------------------
Name (time in s)                                           Runtime         
-----------------------------------------------------------------------
test_radon_model_compile_variants_benchmark[C_VM]      16.4719 (1.0)   
test_radon_model_compile_variants_benchmark[NUMBA]     26.0248 (1.58)  
test_radon_model_compile_variants_benchmark[C]         37.1087 (2.25)  
-----------------------------------------------------------------------

And this comes at no cost in evaluation runtime, unlike the VM approach, which is 2x slower in the C backend and many times slower in the naive implementation in #1604 (the Numba overhead per individually jitted function added up there).

-------------------------------------------------------------------------------------------------- benchmark: 4 tests --------------------------------------------------------------------------------------------------
Name (time in us)                                  Min                   Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_radon_model_call_benchmark[NUMBA]         21.3900 (1.0)         35.4870 (1.0)      25.0974 (1.0)       5.9128 (1.88)     22.2920 (1.0)       5.2898 (13.20)         1;1       39.8448 (1.0)           5           1
test_radon_model_call_benchmark[C]             22.6220 (1.06)        72.1550 (2.03)     27.6587 (1.10)      3.1417 (1.0)      27.8320 (1.25)      0.4007 (1.0)     2577;3297       36.1550 (0.91)       9691           1
test_radon_model_call_benchmark[C_VM_NOGC]     33.3120 (1.56)     1,177.2980 (33.18)    46.5852 (1.86)     16.1183 (5.13)     40.5760 (1.82)     12.9245 (32.25)     790;286       21.4661 (0.54)       8988           1
test_radon_model_call_benchmark[C_VM]          44.9940 (2.10)       633.7580 (17.86)    52.8585 (2.11)     11.7635 (3.74)     52.3990 (2.35)      6.5520 (16.35)     570;652       18.9184 (0.47)       8000           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Conclusion so far

  • If you're recompiling the very same function, numba is now as fast as the other backends (and 10x faster than it was before).
  • If you're compiling slightly different functions, numba is 2x slower than the default C_VM backend, but faster than the "equivalent" fully compiled C. Compared to itself it is 2x faster than before, so in general numba will compile anywhere between 2x and 10x faster than before, for similar kinds of models.

TODO:

  • [x] Hash constants correctly
  • [x] Investigate whether doing sha256(sha256(key1), sha256(key2)) has much higher collision chances than sha256(key1, key2) (see the sketch after this list)
  • [x] Split numba benchmark tests and add a numba no-cache mode to the test
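
Sketch of the two hashing schemes compared in the second item (plain hashlib, not the PyTensor key-building code); both produce a 256-bit output, so a collision in either scheme requires a birthday-bound event on 256 bits, either in one of the inner hashes or in the outer one.

```python
from hashlib import sha256

key1, key2 = b"op-source-code", b"node-signature"

# Scheme 1: hash the concatenation of the parts' digests
nested = sha256(sha256(key1).digest() + sha256(key2).digest()).hexdigest()

# Scheme 2: hash the concatenated parts directly
flat = sha256(key1 + key2).hexdigest()
```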

Follow Up PRs

  • [ ] Cache functions with pointers to scipy blas/cython functions: https://github.com/numba/numba/issues/10098#issuecomment-2959818159
  • [x] Cache Scan
  • [ ] Keep trying the numba VM idea, with the Rust glue that @aseyboldt is developing
  • [ ] Reach out to the numba devs to see if we can find a less "gaming" approach, as this will be subject to numba's whims

ricardoV94 avatar Oct 07 '25 18:10 ricardoV94


the suspense is killing me

twiecki avatar Oct 08 '25 03:10 twiecki

Updated: it's 2.5x faster than before.

ricardoV94 avatar Oct 08 '25 20:10 ricardoV94

Is this Good Enough(TM)?

twiecki avatar Oct 09 '25 04:10 twiecki

Is this Good Enough(TM)?

I've managed to cache the whole function graph (still in hack mode), which means if you recompile the exact same graph it's as fast to compile as the other backends (and executes faster, so that's nice). This is not so rare when working in jupyter notebooks and the like.

For compiling slightly different versions (new test) it's 2x faster than before and no longer the slowest backend. The closest is the C (as opposed to C VM), which also compiles the whole graph.

It's still 2x slower than the C VM. This is as expected; that's why the Theano guys moved away from plain C as the default. However, that may actually be fine for now? We can tell users to switch to the C VM backend if the compile times are prohibitive.

ricardoV94 avatar Oct 09 '25 12:10 ricardoV94

Neat! Yeah, sounds like this is good enough to switch the default backend to numba.

twiecki avatar Oct 10 '25 02:10 twiecki

Okay, this is getting there. Missing some docstrings and minor cleanup. The big pain point is the blas/lapack functions, which we'll have to handle manually as well, but I would leave those to a separate PR.

I can re-run the whole numba CI in 5 minutes now, of which the uncacheable blas/lapack stuff takes 3-4 minutes.

This will also unblock https://github.com/pymc-devs/pytensor/pull/1445, which was blocked by our clumsy/eager attempt to cache everything, even the uncacheable. The new system allows us to cleanly inform higher-order functions when a sub-function is uncacheable, so that the functions using it are marked uncacheable as well.
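
For illustration, the propagation can be as simple as treating None as "uncacheable" when composing keys (hypothetical function and key format, not the actual cache API):

```python
from hashlib import sha256

def combine_cache_keys(outer_key, *inner_keys):
    # None means "uncacheable": if any sub-function can't be cached,
    # neither can the higher-order function built on top of it.
    if outer_key is None or any(key is None for key in inner_keys):
        return None
    digest = sha256(outer_key.encode())
    for key in inner_keys:
        digest.update(key.encode())
    return digest.hexdigest()
```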

ricardoV94 avatar Oct 23 '25 21:10 ricardoV94

Codecov Report

:x: Patch coverage is 91.06830% with 51 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 81.75%. Comparing base (60ba7c7) to head (9158fce).
:warning: Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/numba/dispatch/basic.py 86.55% 12 Missing and 4 partials :warning:
pytensor/link/numba/cache.py 85.18% 4 Missing and 4 partials :warning:
pytensor/bin/pytensor_cache.py 16.66% 5 Missing :warning:
pytensor/link/numba/dispatch/compile_ops.py 92.64% 3 Missing and 2 partials :warning:
pytensor/link/numba/dispatch/random.py 66.66% 5 Missing :warning:
pytensor/link/numba/dispatch/blockwise.py 77.77% 4 Missing :warning:
pytensor/link/numba/dispatch/elemwise.py 96.49% 2 Missing :warning:
pytensor/link/numba/dispatch/scan.py 77.77% 1 Missing and 1 partial :warning:
pytensor/link/numba/dispatch/vectorize_codegen.py 75.00% 1 Missing and 1 partial :warning:
pytensor/configdefaults.py 66.66% 1 Missing :warning:
... and 1 more

:x: Your patch check has failed because the patch coverage (91.06%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1637      +/-   ##
==========================================
+ Coverage   81.67%   81.75%   +0.08%     
==========================================
  Files         244      248       +4     
  Lines       53558    53822     +264     
  Branches     9433     9459      +26     
==========================================
+ Hits        43741    44002     +261     
+ Misses       7337     7333       -4     
- Partials     2480     2487       +7     
Files with missing lines Coverage Δ
pytensor/compile/mode.py 85.13% <100.00%> (+0.13%) :arrow_up:
pytensor/configparser.py 92.60% <100.00%> (+0.04%) :arrow_up:
pytensor/link/numba/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/extra_ops.py 96.92% <100.00%> (+0.82%) :arrow_up:
...sor/link/numba/dispatch/linalg/decomposition/lu.py 66.66% <100.00%> (ø)
...or/link/numba/dispatch/linalg/solve/tridiagonal.py 55.39% <100.00%> (ø)
pytensor/link/numba/dispatch/nlinalg.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/shape.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/signal/conv.py 32.69% <100.00%> (ø)
pytensor/link/numba/dispatch/slinalg.py 68.54% <100.00%> (+0.14%) :arrow_up:
... and 16 more

... and 7 files with indirect coverage changes


codecov[bot] avatar Oct 23 '25 21:10 codecov[bot]

Just wrestling with mypy at this point; everything else is ready.

ricardoV94 avatar Oct 25 '25 11:10 ricardoV94

One more manual benchmarking of the radon model:

Before

No caching: 7.5s
Caching except outer FunctionGraph: 6s
Caching: 6s

After

No caching: 8.5s
Caching except outer FunctionGraph: 1.8s
Caching: 0.6s

Starting fresh (without caching) seems a bit slower. However, we later benefit more because the compilation of individual Ops is reused across different calls.

I'm not sure what explains the difference without caching. I got the timings by setting pytensor.config.numba__cache=False. I guess using register_jitable instead of an immediate njit is slower? Or Numba is doing within-runtime caching, and it did so better before than now. I'm not creating intermediate wrapper functions when the cache is disabled, so it's not that there is more fluff.

Otherwise, all the cache-key computation could be what explains the slowdown?
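
Something along these lines is what I mean by toggling the flag (a toy graph standing in for the radon model, not the actual benchmark code):

```python
import time

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = pt.exp(x).sum()

for cache in (True, False):
    pytensor.config.numba__cache = cache
    start = time.perf_counter()
    fn = pytensor.function([x], out, mode="NUMBA")
    print(f"numba__cache={cache}: compiled in {time.perf_counter() - start:.2f}s")
```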

Note that this means our numba CI won't see much speedup (if any at all). It mostly tests single Ops one at a time, so it doesn't benefit much from sharing cached Ops (other than the occasional dimshuffle/indexing that recurs across tests). If you run it twice with caching enabled you'll see a huge difference, so it's still a win for dev time.

Edit: I think I brought the non-cache difference slightly down, although measurements are still slow and noisy.

ricardoV94 avatar Oct 25 '25 15:10 ricardoV94

Do we want to keep the eval_obj_mode thing in compare_numba_and_py? It requires us to always use numba_basic.numba_njit instead of importing numba_njit directly, because the latter can't be properly patched. Also, if you forget to use the monkeypatchable form in a file you may get a cryptic segfault/numba error, because it tries to use an njit function that is not meant to be called directly from python (that's the default when you call numba_njit, which has final_function=False).
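
The limitation is just standard Python name binding; a generic illustration with math.sqrt (nothing PyTensor-specific):

```python
import math
from math import sqrt
from unittest import mock

def via_module():
    return math.sqrt(4.0)   # attribute looked up on the module at call time

def via_direct_import():
    return sqrt(4.0)        # name bound once at import time

with mock.patch("math.sqrt", lambda x: -1.0):
    print(via_module())         # -1.0: the patched attribute is seen
    print(via_direct_import())  # 2.0: the original binding is still used
```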

ricardoV94 avatar Oct 27 '25 14:10 ricardoV94

Do you know of any places where we need it?

jessegrabowski avatar Oct 27 '25 14:10 jessegrabowski

Do you know of any places where we need it?

We don't need it. This is just for code coverage and, supposedly, to surface trivial Python errors in python mode instead of the unreadable numba tracebacks when developing.

ricardoV94 avatar Oct 27 '25 15:10 ricardoV94

Tests are passing and all immediate TODOs have been addressed.

ricardoV94 avatar Oct 27 '25 16:10 ricardoV94

but can't we usually just ask the Op itself to give us a cache key, and move some of the new code there?

I think this is backend-specific. The same Op may require distinct functions in one backend (like numba does with Subtensor) but a single one in another (jax is just x[indices]). But I may be missing what you had in mind specifically?

ricardoV94 avatar Nov 03 '25 10:11 ricardoV94