pytensor
Implement `BandedDot` `Op`
Description
This PR adds a `BandedDot` `Op` that uses `gbmv` to do matrix-vector multiplication when A is a banded matrix.
In my testing, this case sped up computation significantly. Benchmarked against PyTensor's `dot`, however, the current implementation is significantly slower for every size except the largest (10,000):
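`gbmv` expects A packed into a (kl + ku + 1, n) array in which each row holds one diagonal. Below is a minimal NumPy/SciPy sketch of that storage format. It assumes a square A with `kl` sub-diagonals and `ku` super-diagonals. `to_banded` and `banded_dot` are hypothetical names chosen for illustration, not the PR's code.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def to_banded(A, kl, ku):
    """Pack a square matrix A into the (kl + ku + 1, n) band storage used by gbmv.

    Row ku - k holds diagonal k of A (k > 0 is above the main diagonal), so
    ab[ku + i - j, j] == A[i, j] for every entry inside the band.
    """
    n = A.shape[1]
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for k in range(ku, -kl - 1, -1):
        row = ku - k
        if k >= 0:
            ab[row, k:] = np.diag(A, k=k)  # super-diagonals are shifted right
        else:
            ab[row, : n + k] = np.diag(A, k=k)  # sub-diagonals start at column 0
    return ab


def banded_dot(A, x, kl, ku):
    """Compute A @ x through BLAS gbmv (sgbmv or dgbmv, chosen from A's dtype)."""
    gbmv = get_blas_funcs("gbmv", dtype=A.dtype)
    m, n = A.shape
    return gbmv(m, n, kl, ku, 1.0, to_banded(A, kl, ku), x)
```

The per-diagonal loop follows the same pattern as the `make_thunk` code later in this thread. Comparing the result against a dense `A @ x` is the quickest way to confirm the row/offset convention.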
benchmark: 8 tests

| Name (time in us) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| `test_dot_perf[10]` | 1.7500 (1.0) | 17.3330 (1.0) | 1.9054 (1.0) | 0.1292 (1.0) | 1.9160 (1.0) | 0.0420 (1.0) | 585;1740 | 524,831.2234 (1.0) | 38401 | 1 |
| `test_banded_dot_perf[10]` | 19.9580 (11.40) | 13,765.1250 (794.16) | 32.5111 (17.06) | 282.5468 (>1000.0) | 20.5830 (10.74) | 0.3750 (8.93) | 6;349 | 30,758.7051 (0.06) | 3275 | 1 |
| `test_dot_perf[100]` | 2.4580 (1.40) | 42.5420 (2.45) | 2.7856 (1.46) | 0.3265 (2.53) | 2.7500 (1.44) | 0.0420 (1.0) | 343;7436 | 358,988.7425 (0.68) | 71429 | 1 |
| `test_banded_dot_perf[100]` | 19.8330 (11.33) | 15,203.3750 (877.13) | 30.9185 (16.23) | 193.8617 (>1000.0) | 20.9580 (10.94) | 0.4160 (9.90) | 51;3057 | 32,343.1413 (0.06) | 20566 | 1 |
| `test_dot_perf[1000]` | 15.0000 (8.57) | 61.5000 (3.55) | 16.6383 (8.73) | 1.4182 (10.98) | 17.2920 (9.03) | 2.2080 (52.57) | 905;126 | 60,102.3508 (0.11) | 18377 | 1 |
| `test_banded_dot_perf[1000]` | 27.0420 (15.45) | 423.8750 (24.45) | 32.9042 (17.27) | 5.2005 (40.25) | 32.6250 (17.03) | 0.6250 (14.88) | 129;1334 | 30,391.2634 (0.06) | 12501 | 1 |
| `test_dot_perf[10_000]` | 3,369.4580 (>1000.0) | 5,011.3330 (289.12) | 3,412.7784 (>1000.0) | 119.9981 (928.81) | 3,394.5625 (>1000.0) | 17.2910 (411.69) | 4;25 | 293.0164 (0.00) | 198 | 1 |
| `test_banded_dot_perf[10_000]` | 109.9170 (62.81) | 611.5830 (35.28) | 139.2751 (73.10) | 52.3002 (404.81) | 116.5000 (60.80) | 14.0000 (333.33) | 472;678 | 7,180.0341 (0.01) | 3386 | 1 |

I guess there's some major overhead from doing the diagonal extractions and looking up the BLAS function in Python? This could, and probably should, be a C `Op`, but I'm not sure I realistically have time to dig into all that anytime soon. Help wanted, at any rate.
Related Issue
- [ ] Closes #1415
- [ ] Related to #1323
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1416.org.readthedocs.build/en/1416/
I added `trust_input`, and I now load the BLAS functions once at import time and cache them. That should remove some of the most obvious sources of Python overhead. New benchmarks (note that times are now in ns, not us):
benchmark: 8 tests

| Name (time in ns) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| `test_banded_dot_perf[10-dot]` | 541.9988 (1.0) | 4,292.0001 (1.0) | 638.1136 (1.0) | 51.0902 (1.0) | 625.0011 (1.0) | 41.0000 (40.91) | 1506;209 | 1,567,119.1257 (1.0) | 15636 | 1 |
| `test_banded_dot_perf[10-banded_dot]` | 17,500.0005 (32.29) | 418,167.0010 (97.43) | 18,191.1183 (28.51) | 3,829.7598 (74.96) | 18,083.0011 (28.93) | 167.0014 (166.62) | 70;630 | 54,971.8815 (0.04) | 11353 | 1 |
| `test_banded_dot_perf[100-dot]` | 1,209.0004 (2.23) | 23,959.0008 (5.58) | 1,340.3628 (2.10) | 103.1441 (2.02) | 1,333.0009 (2.13) | 1.0023 (1.0) | 1217;34675 | 746,066.6804 (0.48) | 88889 | 1 |
| `test_banded_dot_perf[100-banded_dot]` | 17,542.0009 (32.37) | 77,083.9997 (17.96) | 18,240.8191 (28.59) | 1,230.1810 (24.08) | 18,000.0006 (28.80) | 250.0001 (249.44) | 654;2431 | 54,822.0996 (0.03) | 19018 | 1 |
| `test_banded_dot_perf[1000-dot]` | 13,291.9995 (24.52) | 49,874.9996 (11.62) | 15,195.7498 (23.81) | 1,137.7872 (22.27) | 15,833.0004 (25.33) | 1,832.9993 (>1000.0) | 2954;119 | 65,807.8747 (0.04) | 22347 | 1 |
| `test_banded_dot_perf[1000-banded_dot]` | 24,624.9983 (45.43) | 74,874.9990 (17.45) | 30,233.2753 (47.38) | 1,347.0049 (26.37) | 30,125.0002 (48.20) | 375.0010 (374.15) | 874;1333 | 33,076.1385 (0.02) | 15595 | 1 |
| `test_banded_dot_perf[10_000-dot]` | 3,394,874.9988 (>1000.0) | 5,084,541.9992 (>1000.0) | 3,585,834.0104 (>1000.0) | 191,227.5142 (>1000.0) | 3,558,604.5005 (>1000.0) | 199,729.5003 (>1000.0) | 16;3 | 278.8752 (0.00) | 192 | 1 |
| `test_banded_dot_perf[10_000-banded_dot]` | 105,208.0006 (194.11) | 389,250.0008 (90.69) | 124,879.6041 (195.70) | 35,967.3472 (704.00) | 110,375.0001 (176.60) | 8,343.4998 (>1000.0) | 320;440 | 8,007.7128 (0.01) | 2665 | 1 |

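Here is a sketch of the "look the BLAS routine up once at import" idea. It assumes only float32/float64 inputs. `_GBMV_BY_DTYPE` and `cached_gbmv` are hypothetical names, not necessarily how the PR stores the functions. `scipy.linalg.get_blas_funcs` runs twice at import, and every later call costs a single dict lookup.

```python
import numpy as np
from scipy.linalg import get_blas_funcs

# Resolve the BLAS routines once, at import time.
# get_blas_funcs picks the s/d prefix from the requested dtype.
_GBMV_BY_DTYPE = {
    np.dtype(np.float32): get_blas_funcs("gbmv", dtype=np.float32),
    np.dtype(np.float64): get_blas_funcs("gbmv", dtype=np.float64),
}


def cached_gbmv(ab, x, m, kl, ku):
    """Compute A @ x, where A (m rows) is given in (kl + ku + 1, n) band storage `ab`."""
    gbmv = _GBMV_BY_DTYPE[ab.dtype]  # dict lookup instead of a fresh get_blas_funcs call
    return gbmv(m, ab.shape[1], kl, ku, 1.0, ab, x)
```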
I think the Op is fine, especially if we are not trying to introduce it automatically via rewrites. If we are, we may want to consider the backend (once we have it in Numba, I suspect it will win for smaller matrices) and/or static shapes, if we think the worst-case penalty is still too big.
Benchmark after tuning up the `_to_banded_form` function:

benchmark: 8 tests

| Name (time in ns) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| `test_banded_dot_perf[10-dot]` | 499.9965 (1.0) | 55,500.0006 (1.41) | 665.4888 (1.0) | 390.9718 (1.0) | 666.0011 (1.0) | 42.0005 (1.00) | 31;2639 | 1,502,654.9287 (1.0) | 32129 | 1 |
| `test_banded_dot_perf[10-banded_dot]` | 2,832.9996 (5.67) | 71,957.9984 (1.82) | 3,356.9474 (5.04) | 782.8860 (2.00) | 3,332.9998 (5.00) | 332.9988 (7.93) | 1874;2239 | 297,889.6806 (0.20) | 32833 | 1 |
| `test_banded_dot_perf[100-dot]` | 1,000.0003 (2.00) | 58,208.9997 (1.47) | 1,191.9862 (1.79) | 396.5918 (1.01) | 1,166.9981 (1.75) | 41.9968 (1.0) | 305;3163 | 838,935.8643 (0.56) | 91258 | 1 |
| `test_banded_dot_perf[100-banded_dot]` | 3,332.9998 (6.67) | 39,499.9988 (1.0) | 3,874.8349 (5.82) | 471.5917 (1.21) | 3,875.0004 (5.82) | 84.0009 (2.00) | 1020;11972 | 258,075.5142 (0.17) | 71008 | 1 |
| `test_banded_dot_perf[1000-dot]` | 13,584.0019 (27.17) | 118,374.9991 (3.00) | 16,143.5130 (24.26) | 1,984.1144 (5.07) | 16,291.0001 (24.46) | 2,042.0011 (48.62) | 1390;171 | 61,944.3861 (0.04) | 14202 | 1 |
| `test_banded_dot_perf[1000-banded_dot]` | 8,167.0005 (16.33) | 68,749.9996 (1.74) | 10,694.7895 (16.07) | 1,131.4230 (2.89) | 11,000.0001 (16.52) | 416.9997 (9.93) | 6811;7582 | 93,503.4764 (0.06) | 32521 | 1 |
| `test_banded_dot_perf[10_000-dot]` | 3,379,415.9972 (>1000.0) | 3,680,959.0019 (93.19) | 3,463,207.0645 (>1000.0) | 79,485.8545 (203.30) | 3,434,124.9993 (>1000.0) | 114,541.9992 (>1000.0) | 6;0 | 288.7497 (0.00) | 31 | 1 |
| `test_banded_dot_perf[10_000-banded_dot]` | 93,582.9994 (187.17) | 294,458.0010 (7.45) | 100,154.2338 (150.50) | 22,660.4163 (57.96) | 95,479.0012 (143.36) | 2,083.4996 (49.61) | 10;27 | 9,984.6004 (0.01) | 248 | 1 |

That looks much better!
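One possible alternative to packing the band with a Python loop over diagonals is a single fancy-indexing gather. The index arrays are computed once per `(n, kl, ku)` and then reused. This is an untested sketch that assumes a square A. `band_gather_indices` and `to_banded_gather` are hypothetical, and whether this approach beats the per-diagonal `np.diag` loop in the `make_thunk` code in this thread would need benchmarking.

```python
import numpy as np


def band_gather_indices(n, kl, ku):
    """Hypothetical helper: index arrays mapping a square (n, n) A into band storage.

    Entry A[i, j] with -kl <= j - i <= ku lands at ab[ku + i - j, j].
    """
    k = np.arange(ku, -kl - 1, -1)[:, None]  # diagonal offset for each band row
    j = np.arange(n)[None, :]                # column index
    i = j - k                                # row of A feeding ab[row, j]
    valid = (i >= 0) & (i < n)               # drop the padding corners of the band
    rows = np.broadcast_to(np.arange(kl + ku + 1)[:, None], valid.shape)
    cols = np.broadcast_to(j, valid.shape)
    return rows[valid], cols[valid], i[valid]


def to_banded_gather(A, kl, ku, idx=None):
    """Pack A into (kl + ku + 1, n) band storage using one gather and one scatter."""
    n = A.shape[1]
    rows, cols, src_rows = idx if idx is not None else band_gather_indices(n, kl, ku)
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    ab[rows, cols] = A[src_rows, cols]
    return ab
```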
I agree Numba will probably be better across the board. I'd really like this Op to win on the 100x100 case; that's already a pretty big matrix. 1000x1000 and 10,000x10,000 matrices don't really show up in nature too often.
At 100x100, `dot` takes about 1 us, so you are at the edge of Python overhead there. Calling an identity PyTensor function without `trust_input` takes 300-500 ns, and calling `np.zeros` takes about 100-200 ns. That means you would basically need to have no Python overhead whatsoever.
Edit: those numbers are from my machine; yours may differ.
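These overhead figures vary by machine. Below is a minimal way to measure per-call costs yourself with the standard-library `timeit` module. `per_call_ns` is a hypothetical helper, and it reports the best of several repeats to reduce noise.

```python
import timeit

import numpy as np


def per_call_ns(stmt, setup="pass", number=100_000, repeat=5, namespace=None):
    """Best-of-`repeat` time for a single execution of `stmt`, in nanoseconds."""
    times = timeit.repeat(stmt, setup=setup, number=number, repeat=repeat, globals=namespace)
    return min(times) / number * 1e9


if __name__ == "__main__":
    A = np.random.default_rng(0).normal(size=(100, 100))
    x = np.ones(100)
    ns = {"np": np, "A": A, "x": x}
    print(f"np.zeros((3, 100)): {per_call_ns('np.zeros((3, 100))', namespace=ns):.0f} ns")
    print(f"A @ x (100x100):    {per_call_ns('A @ x', namespace=ns):.0f} ns")
```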
I think this is about the best we can get out of this in Python:
```python
import numpy as np
import scipy.linalg as scipy_linalg


def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
    # Number of sub- and super-diagonals stored on the Op
    kl = self.lower_diags
    ku = self.upper_diags

    # Look up the BLAS routine once, when the thunk is built, not on every call
    if node.outputs[0].dtype == "float64":
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
    else:
        gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")

    ab_size = kl + ku + 1
    a_storage = storage_map[node.inputs[0]]
    b_storage = storage_map[node.inputs[1]]
    out_storage = storage_map[node.outputs[0]]
    out_computed = compute_map[node.outputs[0]] if compute_map is not None else [False]

    # Bind everything as default arguments so the hot path only does local lookups
    def thunk(
        a_storage=a_storage,
        b_storage=b_storage,
        out_storage=out_storage,
        out_computed=out_computed,
        kl=kl,
        ku=ku,
        ab_size=ab_size,
        gbmv=gbmv,
    ):
        A = a_storage[0]
        b = b_storage[0]
        m, n = A.shape

        # Pack A into band storage: row i holds diagonal k = ku - i.
        # Note: f2py may copy a C-ordered array to Fortran order before calling BLAS.
        ab = np.zeros((ab_size, n), dtype=A.dtype, order="C")
        for i, k in enumerate(range(ku, -kl - 1, -1)):
            if k > 0:
                ab[i, k:] = np.diag(A, k=k)
            else:
                ab[i, : n + k] = np.diag(A, k=k)

        out_storage[0] = gbmv(m, n, kl, ku, 1, ab, b)
        out_computed[0] = True

    return thunk
```
I'm not saying we should do that, but it gives you a lower bound on what to expect from your micro-optimizations.
Here's how the thunk version benchmarks for me:
benchmark: 8 tests

| Name (time in ns) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| `test_banded_dot_perf[10-dot]` | 582.9970 (1.0) | 7,208.0002 (1.0) | 648.7823 (1.0) | 105.4763 (1.0) | 625.0011 (1.0) | 41.9968 (1.0) | 184;252 | 1,541,349.0560 (1.0) | 18434 | 1 |
| `test_banded_dot_perf[10-banded_dot]` | 2,749.9991 (4.72) | 28,665.9997 (3.98) | 2,954.8453 (4.55) | 350.8606 (3.33) | 2,917.0005 (4.67) | 42.9973 (1.02) | 555;5229 | 338,427.1940 (0.22) | 39868 | 1 |
| `test_banded_dot_perf[100-dot]` | 1,042.0008 (1.79) | 15,624.9989 (2.17) | 1,178.4495 (1.82) | 197.8076 (1.88) | 1,166.9981 (1.87) | 42.0005 (1.00) | 512;1917 | 848,572.6277 (0.55) | 100848 | 1 |
| `test_banded_dot_perf[100-banded_dot]` | 3,166.9988 (5.43) | 33,166.9980 (4.60) | 3,418.6797 (5.27) | 364.1081 (3.45) | 3,415.9966 (5.47) | 83.0005 (1.98) | 826;2615 | 292,510.5862 (0.19) | 65574 | 1 |
| `test_banded_dot_perf[1000-dot]` | 13,334.0000 (22.87) | 45,625.0018 (6.33) | 15,480.3238 (23.86) | 1,366.7475 (12.96) | 15,957.9977 (25.53) | 1,958.0002 (46.62) | 1490;223 | 64,598.1318 (0.04) | 20426 | 1 |
| `test_banded_dot_perf[1000-banded_dot]` | 8,541.9997 (14.65) | 50,667.0003 (7.03) | 10,089.9543 (15.55) | 777.8152 (7.37) | 10,416.9994 (16.67) | 1,290.9986 (30.74) | 11635;128 | 99,108.4762 (0.06) | 38096 | 1 |
| `test_banded_dot_perf[10_000-dot]` | 3,365,791.9994 (>1000.0) | 5,034,374.9972 (698.44) | 3,495,052.0250 (>1000.0) | 345,179.3641 (>1000.0) | 3,410,270.5013 (>1000.0) | 47,562.5002 (>1000.0) | 2;3 | 286.1188 (0.00) | 40 | 1 |
| `test_banded_dot_perf[10_000-banded_dot]` | 80,417.0013 (137.94) | 454,208.9991 (63.01) | 119,363.4743 (183.98) | 65,435.1952 (620.38) | 91,417.0014 (146.27) | 38,540.9949 (917.71) | 33;33 | 8,377.7722 (0.01) | 350 | 1 |

I'm curious whether it's possible to destroy A and turn it into A_banded in place. If it is possible, it doesn't seem trivial. BLAS `gbmv` doesn't have an `overwrite_x` option, so b can't be destroyed either.
Frankly, my time would be better spent thinking about how to do this in C at this point.
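On the output side, `y` is the only array `gbmv` writes to, and SciPy's wrapper exposes `overwrite_y`. The sketch below assumes float64 data and a caller-owned output buffer. `banded_matvec_into` is a hypothetical name. With `beta=0.0`, the previous contents of the buffer are ignored. Whether the wrapper actually reuses the buffer depends on it being a contiguous float64 array, so treat the return value, not `out`, as the result.

```python
import numpy as np
from scipy.linalg import get_blas_funcs

dgbmv = get_blas_funcs("gbmv", dtype=np.float64)


def banded_matvec_into(ab, x, out, m, kl, ku):
    """Compute A @ x into a preallocated buffer, where `ab` holds A in band storage.

    beta=0.0 means the old contents of `out` do not contribute; overwrite_y=1 allows
    the f2py wrapper to write into `out` instead of allocating a new array (it may
    still copy if `out` is not a suitable contiguous float64 array).
    """
    return dgbmv(m, ab.shape[1], kl, ku, 1.0, ab, x, beta=0.0, y=out, overwrite_y=1)
```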
Also we should probably be benchmarking this against sparse_dot -- this all might be a waste of time?
Well, `SparseDot` doesn't work with batched inputs, but I'm curious. Also, I don't think the code is too complex or performing too badly. I don't agree with the sentiment that we should be thinking about a C implementation; a Numba one is more interesting...
Why not both?
Seriously, though, my feeling is that if we're putting this stuff into a PyMC model, the code has to be ultra-performant. It's going to be called umptillion times: the inner loop of a PDE solver times the MCMC loop.
I'll work on the Numba dispatch next, at any rate.
By that argument, you can't really add any specialized Op that doesn't have a C implementation (unless it's replacing an Op that also doesn't have a C implementation).
Setting the general user aside, you can have code that decides whether to use this Op based on the size (or a rewrite that does so). Also, how are you sampling/getting A? Can you avoid boxing and unboxing the diagonals?
Well, the point is that the specialization isn't adding anything over good ol' `pt.dot` (yet!), except for really huge matrices.
Numba benchmarks. I think the Numba dispatch could be optimized further.
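Here is a minimal sketch of the size-based dispatch idea. It assumes a square A and a hand-picked threshold. `matvec_auto` and `min_banded_size` are hypothetical names. The default of 200 only echoes the rough crossover reported later in this thread and is not a tuned constant.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def matvec_auto(A, x, kl, ku, min_banded_size=200):
    """Hypothetical size-based dispatch between dense and banded matrix-vector products.

    Small (or non-square) matrices use the dense A @ x. Larger square ones are packed
    into (kl + ku + 1, n) band storage and sent to BLAS gbmv.
    """
    m, n = A.shape
    if n < min_banded_size or m != n:
        return A @ x
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for k in range(ku, -kl - 1, -1):
        # Diagonal k of A goes to band row ku - k (shifted right for k > 0)
        if k >= 0:
            ab[ku - k, k:] = np.diag(A, k)
        else:
            ab[ku - k, : n + k] = np.diag(A, k)
    gbmv = get_blas_funcs("gbmv", dtype=A.dtype)
    return gbmv(m, n, kl, ku, 1.0, ab, x)
```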
benchmark: 16 tests

| Name (time in ns) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| `test_banded_dot_perf[10-FAST_RUN-dot]` | 582.9934 (1.0) | 5,582.9951 (1.0) | 669.4463 (1.0) | 72.6255 (1.0) | 666.0048 (1.0) | 41.9968 (2.00) | 2773;2451 | 1,493,771.7737 (1.0) | 38524 | 1 |
| `test_banded_dot_perf[10-NUMBA-dot]` | 1,124.9940 (1.93) | 44,874.9997 (8.04) | 1,282.0732 (1.92) | 234.5687 (3.23) | 1,250.0022 (1.88) | 42.0041 (2.00) | 2930;7076 | 779,986.6647 (0.52) | 119404 | 1 |
| `test_banded_dot_perf[10-FAST_RUN-banded_dot]` | 3,332.9998 (5.72) | 39,458.0056 (7.07) | 3,538.1992 (5.29) | 325.6973 (4.48) | 3,500.0048 (5.26) | 123.9969 (5.91) | 1430;1601 | 282,629.6514 (0.19) | 84503 | 1 |
| `test_banded_dot_perf[10-NUMBA-banded_dot]` | 1,250.0022 (2.14) | 1,027,250.0003 (184.00) | 1,420.5034 (2.12) | 2,789.3330 (38.41) | 1,415.9959 (2.13) | 42.0041 (2.00) | 180;6730 | 703,975.7880 (0.47) | 136352 | 1 |
| `test_banded_dot_perf[100-FAST_RUN-dot]` | 1,051.9998 (1.80) | 10,291.7493 (1.84) | 1,110.6063 (1.66) | 81.3833 (1.12) | 1,104.2503 (1.66) | 20.9984 (1.0) | 395;4285 | 900,409.0542 (0.60) | 191976 | 4 |
| `test_banded_dot_perf[100-NUMBA-dot]` | 2,957.9969 (5.07) | 57,625.0022 (10.32) | 3,456.9208 (5.16) | 389.0581 (5.36) | 3,417.0007 (5.13) | 125.0010 (5.95) | 553;3677 | 289,274.7781 (0.19) | 105731 | 1 |
| `test_banded_dot_perf[100-FAST_RUN-banded_dot]` | 3,666.9953 (6.29) | 41,667.0009 (7.46) | 3,981.1827 (5.95) | 360.0852 (4.96) | 3,957.9973 (5.94) | 165.9937 (7.91) | 2100;2312 | 251,181.6399 (0.17) | 121833 | 1 |
| `test_banded_dot_perf[100-NUMBA-banded_dot]` | 1,666.9946 (2.86) | 28,708.0038 (5.14) | 1,867.6628 (2.79) | 241.6104 (3.33) | 1,833.9997 (2.75) | 42.0041 (2.00) | 1994;13553 | 535,428.5652 (0.36) | 130430 | 1 |
| `test_banded_dot_perf[1000-FAST_RUN-dot]` | 13,874.9965 (23.80) | 50,374.9980 (9.02) | 15,947.8134 (23.82) | 1,284.7185 (17.69) | 16,540.9947 (24.84) | 2,000.9975 (95.29) | 5008;216 | 62,704.5209 (0.04) | 40540 | 1 |
| `test_banded_dot_perf[1000-NUMBA-dot]` | 366,542.0008 (628.72) | 944,250.0013 (169.13) | 394,639.3455 (589.50) | 47,299.9619 (651.29) | 375,083.4985 (563.18) | 29,312.4976 (>1000.0) | 238;221 | 2,533.9592 (0.00) | 2104 | 1 |
| `test_banded_dot_perf[1000-FAST_RUN-banded_dot]` | 8,625.0002 (14.79) | 136,833.9945 (24.51) | 10,687.2292 (15.96) | 1,682.1809 (23.16) | 11,165.9974 (16.77) | 1,873.9956 (89.24) | 363;264 | 93,569.6219 (0.06) | 62337 | 1 |
| `test_banded_dot_perf[1000-NUMBA-banded_dot]` | 6,292.0008 (10.79) | 37,042.0021 (6.63) | 7,550.0573 (11.28) | 633.0026 (8.72) | 7,709.0008 (11.57) | 249.9946 (11.91) | 17365;17601 | 132,449.3267 (0.09) | 84804 | 1 |
| `test_banded_dot_perf[10_000-FAST_RUN-dot]` | 3,360,708.0040 (>1000.0) | 4,747,791.0048 (850.40) | 3,401,666.0148 (>1000.0) | 99,592.4120 (>1000.0) | 3,388,875.0004 (>1000.0) | 21,333.5006 (>1000.0) | 3;14 | 293.9736 (0.00) | 200 | 1 |
| `test_banded_dot_perf[10_000-NUMBA-dot]` | 75,194,832.9998 (>1000.0) | 79,603,917.0018 (>1000.0) | 77,050,570.3840 (>1000.0) | 1,533,475.7321 (>1000.0) | 77,686,790.9954 (>1000.0) | 2,637,947.7476 (>1000.0) | 6;0 | 12.9785 (0.00) | 13 | 1 |
| `test_banded_dot_perf[10_000-FAST_RUN-banded_dot]` | 87,583.0010 (150.23) | 547,832.9986 (98.13) | 151,780.8135 (226.73) | 56,039.1835 (771.62) | 152,041.0005 (228.29) | 47,916.0062 (>1000.0) | 660;123 | 6,588.4480 (0.00) | 2659 | 1 |
| `test_banded_dot_perf[10_000-NUMBA-banded_dot]` | 81,334.0012 (139.51) | 461,791.9985 (82.71) | 102,538.0625 (153.17) | 32,585.4856 (448.68) | 89,959.0013 (135.07) | 4,094.2505 (194.98) | 896;1290 | 9,752.4761 (0.01) | 7569 | 1 |

Numba's regular dot is doing something crappy. If you make it call gemv yourself, it should be as good as the C dot (which I guess is being rewritten into Gemv?).
Otherwise, this still shows the regular C dot winning over the Numba banded dot at n=100.
You may want to add n=256 or so to see where the threshold is; it's a big jump from iffy at n=100 to clearly better at n=1000.
Or make a plot instead?
Look how shitty Numba dot is. Moving on...
Looks like Numba is a constant factor faster, and this specialized dot is better above roughly 200x200.
That's much more palatable.
The difference between the Numba and Python gbmv is also roughly what you should expect to see if you implemented gbmv in C, so you don't have to wonder.
🤔
The problem in the timings was some copies being made in both Python and Numba mode. Here are the updated timings. They're essentially the same except at the low end, where getting rid of the Python overhead gives Numba a small but consistent speed bump.
I'd like to call this one done for now, although there are three major things left to do:
- Enable GEMV rewrites in Numba and reuse that machinery to allow all arguments to the Numba xgemv function. Right now I'm forcing alpha=1, beta=0.
- Split the code that converts a dense banded matrix into banded storage form off into a separate Op. Then we can add a rewrite to, for example, lift that conversion outside of scan. More importantly, it lets us:
- Introduce a rewrite that converts `GEMV(BandedMatrix(A), x, ...)` into `BandedGEMV(BandedMatrix(A), x, ...)`. The existing `BandedDot` can become `BandedGEMV`, and we can use all the arguments.
I want to merge this and then do these three things, because I want to do #1418 first, put the resulting function into the new `_BLAS.py` file from this PR, enable the relevant rewrites, and then revisit this code.
I also need to think about how to handle splitting out the `BandedMatrix` Op, because it destroys information about how many rows the input matrix has (gemv needs to know this).
Codecov Report
Attention: Patch coverage is 71.90083% with 34 lines in your changes missing coverage. Please review.
Project coverage is 82.09%. Comparing base (d10f245) to head (976422f).
:x: Your patch check has failed because the patch coverage (71.90%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
```diff
@@            Coverage Diff             @@
##             main    #1416      +/-   ##
==========================================
- Coverage   82.12%   82.09%   -0.03%
==========================================
  Files         211      213       +2
  Lines       49757    49878     +121
  Branches     8819     8826       +7
==========================================
+ Hits        40862    40949      +87
- Misses       6715     6746      +31
- Partials     2180     2183       +3
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/link/numba/dispatch/basic.py | 79.08% <ø> (ø) | |
| pytensor/link/numba/dispatch/linalg/_BLAS.py | 100.00% <100.00%> (ø) | |
| pytensor/link/numba/dispatch/slinalg.py | 70.10% <75.00%> (+0.34%) | :arrow_up: |
| pytensor/tensor/slinalg.py | 93.00% <90.00%> (-0.19%) | :arrow_down: |
| pytensor/link/numba/dispatch/linalg/dot/banded.py | 46.00% <46.00%> (ø) | |

@ricardoV94, since #1418 got resolved without adding the GEMV rewrite to Numba, how should I handle expanding this Op to include rank-1 updates?
We may still link directly to BLAS for the full update; I'm not sure Numba does it beyond dispatching the matrix-vector dot part.
I would start by benchmarking directly in Numba to see whether we get a speedup from calling the fused gemv directly, or whether Numba already does that (for the regular gemv; it certainly doesn't do it for gbmv).
I just pushed a major refactor to this PR, which:
1. Renames `BandedDot` to `BandedGEMV` (which is what it actually is; the underlying routine is called GBMV, but I thought "banded GEMV" was clearer).
2. Adds support for all GBMV arguments (A, x, y, alpha, beta) in `BandedGEMV`.
3. Adjusts the Numba overload accordingly.
4. Adds a Numba overload for GEMV itself. Note that this will never be used, because we don't include `BlasOpt` in the Numba rewrites.
Regarding point (4), here are the benchmarks using the Numba GEMV overload vs. what we currently get with `mode="NUMBA"`:
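For reference, the full GBMV contract that `BandedGEMV` now exposes is y_out = alpha * (A @ x) + beta * y. The sketch below uses SciPy's wrapper and assumes A is already in (kl + ku + 1, n) band storage. `banded_gemv` is a hypothetical name and is not the PR's Numba overload.

```python
import numpy as np
from scipy.linalg import get_blas_funcs


def banded_gemv(ab, x, y, alpha, beta, m, kl, ku):
    """Return alpha * (A @ x) + beta * y, with A (m rows) given in band storage `ab`."""
    gbmv = get_blas_funcs("gbmv", dtype=ab.dtype)
    # overwrite_y is left at its default (0), so the wrapper works on a copy of y
    # and the caller's array is not modified.
    return gbmv(m, ab.shape[1], kl, ku, alpha, ab, x, beta=beta, y=y)
```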
benchmark: 6 tests

| Name (time in us) | Min | Max | Mean | StdDev | Median | IQR | Outliers | OPS (Kops/s) | Rounds | Iterations |
|---|---|---|---|---|---|---|---|---|---|---|
| `test_numba_gemv_benchmark[numba-10]` | 6.7500 (1.0) | 109.3330 (1.32) | 7.9952 (1.0) | 1.6793 (1.09) | 7.8750 (1.0) | 0.2910 (1.40) | 878;1881 | 125.0747 (1.0) | 29888 | 1 |
| `test_numba_gemv_benchmark[numba+blas-10]` | 7.5000 (1.11) | 82.7080 (1.0) | 8.0753 (1.01) | 1.9040 (1.24) | 7.8750 (1.00) | 0.2080 (1.0) | 252;986 | 123.8337 (0.99) | 14185 | 1 |
| `test_numba_gemv_benchmark[numba-100]` | 8.0410 (1.19) | 129.5830 (1.57) | 10.3897 (1.30) | 1.7014 (1.11) | 10.5830 (1.34) | 1.2920 (6.21) | 157;122 | 96.2488 (0.77) | 29963 | 1 |
| `test_numba_gemv_benchmark[numba+blas-100]` | 7.9170 (1.17) | 89.1250 (1.08) | 10.1977 (1.28) | 1.5343 (1.0) | 10.4170 (1.32) | 1.2080 (5.81) | 218;165 | 98.0610 (0.78) | 32129 | 1 |
| `test_numba_gemv_benchmark[numba-1000]` | 22.2920 (3.30) | 708.7500 (8.57) | 25.6014 (3.20) | 6.1947 (4.04) | 24.9170 (3.16) | 1.6670 (8.01) | 229;1198 | 39.0604 (0.31) | 18824 | 1 |
| `test_numba_gemv_benchmark[numba+blas-1000]` | 21.3330 (3.16) | 186.1250 (2.25) | 24.8268 (3.11) | 4.2993 (2.80) | 24.0830 (3.06) | 1.8750 (9.01) | 448;739 | 40.2790 (0.32) | 17992 | 1 |

It's about the same, or maybe slightly better, but at the cost that we can no longer cache the compiled function, because of the function pointer.
Also note that the test is very sensitive to detection of the alpha parameter. I had to write `alpha * (A @ x) + beta * y` for the GEMV rewrite to correctly find alpha. If it fails to find alpha, `mode="NUMBA"` significantly outperforms the GEMV call.