skglm
ENH Cache Numba compilation for the user
import time
import numpy as np
from numpy.linalg import norm
from sklearn.linear_model import Lasso as Lasso_sk
from skglm.estimators import Lasso
n_samples = 100
n_features = 10_000
X = np.random.normal(0, 1, (n_samples, n_features))
y = np.random.normal(0, 1, (n_samples,))
alpha_max = norm(X.T @ y, ord=np.inf) / n_samples
alpha = alpha_max * 0.1
start = time.time()
clf = Lasso(alpha).fit(X, y)
print("skglm:", time.time() - start)
start = time.time()
clf = Lasso_sk(alpha).fit(X, y)
print("sklearn:", time.time() - start)
This script gives:
skglm: 4.0232319831848145
sklearn: 0.2305459976196289
This is due to the compilation cost. We should cache this compilation once and for all, ideally during install (by pre-building the IR generated by Numba) or on the user's first run, using njit(cache=True).
If I got your point, you suggest shipping skglm with pre-compiled numba code, right?
It's a really tempting suggestion as it will eliminate the overhead of the first run.
Have you tried to upload code to PyPI with pre-compiled Numba code? If so, does it work as expected?
For now, I've only tried adding cache=True to the njit decorator; it does not change anything. I was wondering whether the Numba compilation could be included in a wheel that we ship on PyPI.
Those are very interesting suggestions, and that would be a major plus indeed if possible!
Did you look at ahead-of-time (AOT) compilation too? https://numba.pydata.org/numba-doc/latest/user/pycc.html
I did; it's what we need. A few comments on the limitations though:
1. AOT compilation only allows for regular functions, not ufuncs.
2. You have to specify function signatures explicitly.
3. Each exported function can have only one signature (but you can export several different signatures under different names).
4. AOT compilation produces generic code for your CPU’s architectural family (for example “x86-64”), while JIT compilation produces code optimized for your particular CPU model.
1. I'm not sure about this one: ufuncs are NumPy functions overloaded by Numba, so this could be a problem.
2. This one will give us a bit of work, especially since we want to support both float64 and float32 types.
3. Same as above: we would export one name per dtype.
4. We might see a drop in performance; to be investigated with benchmarks.
Another thing to worry about: jitclass is not really supported by AOT.
For the build integration in setup.py: https://numba.pydata.org/numba-doc/dev/user/pycc.html#distutils-integration
A code snippet to AOT-compile at first import: https://github.com/pymatting/pymatting/blob/838171bdf13ddd474c1e81b7a4f1427bf6da6703/pymatting_aot/cc.py#L1
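That pymatting snippet boils down to a try/except around the import: attempt to load the AOT-built module, and build it once if it is missing. A minimal sketch of the pattern (the module name and builder are hypothetical, not skglm's actual layout):

```python
import importlib

def load_compiled(module_name, build):
    """Import module_name, building it once with build() if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # first import on this machine: build the extension, then retry
        build()  # e.g. run numba.pycc's cc.compile() here
        return importlib.import_module(module_name)

# usage sketch with a stand-in builder (a real one would invoke the C toolchain)
import sys, types

def fake_build():
    sys.modules['skglm_compiled'] = types.ModuleType('skglm_compiled')

mod = load_compiled('skglm_compiled', fake_build)
```

This moves the compilation cost to the very first import rather than every fresh process, at the price of a slow first import.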