HACK: Implement Numba VM with caching of individual nodes
Disclaimer: this is still 100% on hack status, and I don't understand half of the things I did
When we tried #811 it was obvious that numba compile times were prohibitive.
This PR tries a different approach (still at the hacking stage): using a mode more like the CVM, where individual nodes are jit-compiled but the whole graph (i.e., the "VM") is not. This allows reusing pre-compiled/cached nodes across different functions, bringing the compilation cost down.
It requires interfacing with the numba cache locator to direct it to cached objects, which in turn requires defining our own cache keys. Numba usually uses the file location and contents as the cache key, but this doesn't work for dynamically generated files (at least not when they're stored in a random temp file), nor really for nested functions like those built for Elemwise. Also, some Ops are string-generated, while others are regular Python functions with globals, which numba can usually cache. All this has to be re-examined.
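One direction for stable keys (just a sketch of the idea, not what this PR currently does; `source_cache_key` is a hypothetical helper) is to hash the generated source text itself, so the key no longer depends on where, or whether, the source is written to disk:

```python
# Sketch: derive a cache key from the generated source string instead of
# the file path / mtime that numba's default locator relies on.
import hashlib


def source_cache_key(src: str, extra: tuple = ()) -> str:
    """Stable key for dynamically generated source, independent of its
    on-disk location. `extra` can carry Op parameters, dtypes, etc."""
    h = hashlib.sha256()
    h.update(src.encode("utf-8"))
    for item in extra:
        h.update(repr(item).encode("utf-8"))
    return h.hexdigest()


src_a = "def impl(x):\n    return x + 1\n"
src_b = "def impl(x):\n    return x + 2\n"

# Same source -> same key, regardless of which temp file it lands in;
# different source -> different key.
assert source_cache_key(src_a) == source_cache_key(src_a)
assert source_cache_key(src_a) != source_cache_key(src_b)
```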
We are also not calling njit on the inner Ops (the store_core_outputs / ScalarOp) of Elemwise, but instead using register_jitable. This was needed for caching to work: if we njit a function we always get a new object, and once serialized the numba cache key will differ, whereas register_jitable overloads the function but returns it unchanged, which leaves the cache key intact.
This requires us to move the jit away from the dispatch functionality.
Results:
Second pass over tests/tensor/rewriting/test_basic.py (to allow compiling everything first):
- 2s with the C_VM backend
- 54s with the Numba backend
- 34s with the Numba VM without cache
- 4s with the Numba VM with cache
We're finally approaching the speed of the previous backend (at least for single-function compilation + eval). We could probably get all the way there with more optimization, but a small slowdown is acceptable.
TODO:
- [ ] We are still writing python strings to the filesystem to compile them, this is probably not needed as explored in #1326 (last commit?)
- [ ] We have to compile some functions that don't really need compilation just so we can cache them, such as with Elemwise. This is related to https://numba.discourse.group/t/caching-redefined-functions/3057 but I don't yet have a clear picture.
- [ ] Proper cache keys; I just hacked some quick things together. Perhaps use the source code of the generated functions?
- [ ] The Composite cache key is certainly broken
- [ ] Cache the whole FunctionGraph; this would avoid recompiling identical graphs in the regular Numba mode, not just NumbaCVM (it's also needed for correct caching of Composite/Blockwise/Scan/OpFromGraph, i.e., anything with inner Ops)
- [ ] Figure out what happens with Ops that run in object mode
- [ ] Handle functions with pointers / large constants that can't traditionally be cached (not sure what's happening now). Related to https://github.com/numba/numba/issues/10098
- [ ] Benchmark the slowdown from the "VM" approach in realistic functions. Consider using/adapting the CVM to orchestrate the calls to the individual nodes (would need to use the thunk approach). Right now the VM is the Python source code generated by the outermost unjitted FunctionGraph
left some comments; hope you don't mind
I don't, but it's still too early for that sort of feedback. I'm just tinkering around at this point.