
Implement `BandedDot` `Op`

Open jessegrabowski opened this issue 6 months ago • 26 comments

Description

This PR adds a BandedDot Op that uses gbmv to do matrix-vector multiplication for the case that A is a banded matrix.
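For context, the core idea — pack the nonzero diagonals of A into LAPACK band storage and hand them to gbmv — can be sketched in plain numpy/scipy. This is an illustrative reconstruction, not the PR's code; `dense_to_banded` is a hypothetical helper:

```python
import numpy as np
from scipy.linalg import get_blas_funcs

def dense_to_banded(A, kl, ku):
    """Pack the kl + ku + 1 nonzero diagonals of square A into LAPACK band storage."""
    n = A.shape[1]
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for row, k in enumerate(range(ku, -kl - 1, -1)):
        if k >= 0:
            ab[row, k:] = np.diag(A, k=k)       # diagonal k sits in row ku - k, shifted right by k
        else:
            ab[row, : n + k] = np.diag(A, k=k)  # subdiagonals are left-aligned
    return ab

rng = np.random.default_rng(0)
n, kl, ku = 6, 2, 1
A = np.tril(np.triu(rng.normal(size=(n, n)), -kl), ku)  # banded test matrix
x = rng.normal(size=n)

gbmv = get_blas_funcs("gbmv", dtype=A.dtype)  # resolves to dgbmv for float64
y = gbmv(n, n, kl, ku, 1.0, dense_to_banded(A, kl, ku), x)
np.testing.assert_allclose(y, A @ x)
```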

In my testing, I found that using gbmv sped up computation significantly in this case. Benchmarked against PyTensor's dot, however, the current implementation is significantly slower:

------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------
Name (time in us)                       Min                    Max                  Mean              StdDev                Median                IQR            Outliers           OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_dot_perf[10]                    1.7500 (1.0)          17.3330 (1.0)          1.9054 (1.0)        0.1292 (1.0)          1.9160 (1.0)       0.0420 (1.0)      585;1740  524,831.2234 (1.0)       38401           1
test_banded_dot_perf[10]            19.9580 (11.40)    13,765.1250 (794.16)      32.5111 (17.06)    282.5468 (>1000.0)     20.5830 (10.74)     0.3750 (8.93)        6;349   30,758.7051 (0.06)       3275           1

test_dot_perf[100]                   2.4580 (1.40)         42.5420 (2.45)         2.7856 (1.46)       0.3265 (2.53)         2.7500 (1.44)      0.0420 (1.0)      343;7436  358,988.7425 (0.68)      71429           1
test_banded_dot_perf[100]           19.8330 (11.33)    15,203.3750 (877.13)      30.9185 (16.23)    193.8617 (>1000.0)     20.9580 (10.94)     0.4160 (9.90)      51;3057   32,343.1413 (0.06)      20566           1

test_dot_perf[1000]                 15.0000 (8.57)         61.5000 (3.55)        16.6383 (8.73)       1.4182 (10.98)       17.2920 (9.03)      2.2080 (52.57)     905;126   60,102.3508 (0.11)      18377           1
test_banded_dot_perf[1000]          27.0420 (15.45)       423.8750 (24.45)       32.9042 (17.27)      5.2005 (40.25)       32.6250 (17.03)     0.6250 (14.88)    129;1334   30,391.2634 (0.06)      12501           1

test_dot_perf[10_000]            3,369.4580 (>1000.0)   5,011.3330 (289.12)   3,412.7784 (>1000.0)  119.9981 (928.81)   3,394.5625 (>1000.0)  17.2910 (411.69)       4;25      293.0164 (0.00)        198           1
test_banded_dot_perf[10_000]       109.9170 (62.81)       611.5830 (35.28)      139.2751 (73.10)     52.3002 (404.81)     116.5000 (60.80)    14.0000 (333.33)    472;678    7,180.0341 (0.01)       3386           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

I guess there's some major overhead from doing the diagonal extractions and looking up the BLAS function in python? This could, and probably should, be a C Op, but I'm not sure I realistically have time to dig into all that anytime soon. Help wanted, at any rate.

Related Issue

  • [ ] Closes #1415
  • [ ] Related to #1323

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pytensor--1416.org.readthedocs.build/en/1416/

jessegrabowski avatar May 23 '25 08:05 jessegrabowski

I added trust_input, and I now load the BLAS functions once on import and cache them. That should eliminate the most obvious sources of python overhead. New benchmarks (note that they're in ns now, not us):

------------------------------------------------------------------------------------------------------------------- benchmark: 8 tests -------------------------------------------------------------------------------------------------------------------
Name (time in ns)                                      Min                       Max                      Mean                  StdDev                    Median                     IQR            Outliers             OPS            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_banded_dot_perf[10-dot]                      541.9988 (1.0)          4,292.0001 (1.0)            638.1136 (1.0)           51.0902 (1.0)            625.0011 (1.0)           41.0000 (40.91)    1506;209  1,567,119.1257 (1.0)       15636           1
test_banded_dot_perf[10-banded_dot]            17,500.0005 (32.29)      418,167.0010 (97.43)       18,191.1183 (28.51)      3,829.7598 (74.96)       18,083.0011 (28.93)        167.0014 (166.62)     70;630     54,971.8815 (0.04)      11353           1

test_banded_dot_perf[100-dot]                   1,209.0004 (2.23)        23,959.0008 (5.58)         1,340.3628 (2.10)         103.1441 (2.02)         1,333.0009 (2.13)           1.0023 (1.0)    1217;34675    746,066.6804 (0.48)      88889           1
test_banded_dot_perf[100-banded_dot]           17,542.0009 (32.37)       77,083.9997 (17.96)       18,240.8191 (28.59)      1,230.1810 (24.08)       18,000.0006 (28.80)        250.0001 (249.44)   654;2431     54,822.0996 (0.03)      19018           1

test_banded_dot_perf[1000-dot]                 13,291.9995 (24.52)       49,874.9996 (11.62)       15,195.7498 (23.81)      1,137.7872 (22.27)       15,833.0004 (25.33)      1,832.9993 (>1000.0)  2954;119     65,807.8747 (0.04)      22347           1
test_banded_dot_perf[1000-banded_dot]          24,624.9983 (45.43)       74,874.9990 (17.45)       30,233.2753 (47.38)      1,347.0049 (26.37)       30,125.0002 (48.20)        375.0010 (374.15)   874;1333     33,076.1385 (0.02)      15595           1

test_banded_dot_perf[10_000-dot]            3,394,874.9988 (>1000.0)  5,084,541.9992 (>1000.0)  3,585,834.0104 (>1000.0)  191,227.5142 (>1000.0)  3,558,604.5005 (>1000.0)  199,729.5003 (>1000.0)      16;3        278.8752 (0.00)        192           1
test_banded_dot_perf[10_000-banded_dot]       105,208.0006 (194.11)     389,250.0008 (90.69)      124,879.6041 (195.70)    35,967.3472 (704.00)     110,375.0001 (176.60)     8,343.4998 (>1000.0)   320;440      8,007.7128 (0.01)       2665           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
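The import-time lookup could look roughly like this (an illustrative sketch; the names are placeholders, not the PR's actual code):

```python
import numpy as np
from scipy.linalg import get_blas_funcs

# resolve the BLAS routines once at import time, not on every call
_GBMV = {
    "float32": get_blas_funcs("gbmv", dtype="float32"),
    "float64": get_blas_funcs("gbmv", dtype="float64"),
}

def banded_matvec(ab, x, m, kl, ku):
    """Multiply a band-stored matrix by x. ab has shape (kl + ku + 1, n)."""
    # hot path: a dict lookup instead of a scipy get_blas_funcs call
    gbmv = _GBMV[ab.dtype.name]
    return gbmv(m, ab.shape[1], kl, ku, 1.0, ab, x)
```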

jessegrabowski avatar May 23 '25 09:05 jessegrabowski

I think the Op is fine, especially if we are not trying to introduce it automatically via rewrites. If we are, we may consider the backend (once we have it in numba, I suspect it will win for smaller matrices) and/or static shapes, if we think the worst-case penalty is still too big.

ricardoV94 avatar May 23 '25 10:05 ricardoV94

Benchmark after tuning up the _to_banded_form function:

------------------------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------------------------
Name (time in ns)                                      Min                       Max                      Mean                 StdDev                    Median                     IQR            Outliers             OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_banded_dot_perf[10-dot]                      499.9965 (1.0)         55,500.0006 (1.41)           665.4888 (1.0)         390.9718 (1.0)            666.0011 (1.0)           42.0005 (1.00)      31;2639  1,502,654.9287 (1.0)       32129           1
test_banded_dot_perf[10-banded_dot]             2,832.9996 (5.67)        71,957.9984 (1.82)         3,356.9474 (5.04)        782.8860 (2.00)         3,332.9998 (5.00)         332.9988 (7.93)    1874;2239    297,889.6806 (0.20)      32833           1

test_banded_dot_perf[100-dot]                   1,000.0003 (2.00)        58,208.9997 (1.47)         1,191.9862 (1.79)        396.5918 (1.01)         1,166.9981 (1.75)          41.9968 (1.0)      305;3163    838,935.8643 (0.56)      91258           1
test_banded_dot_perf[100-banded_dot]            3,332.9998 (6.67)        39,499.9988 (1.0)          3,874.8349 (5.82)        471.5917 (1.21)         3,875.0004 (5.82)          84.0009 (2.00)   1020;11972    258,075.5142 (0.17)      71008           1

test_banded_dot_perf[1000-dot]                 13,584.0019 (27.17)      118,374.9991 (3.00)        16,143.5130 (24.26)     1,984.1144 (5.07)        16,291.0001 (24.46)      2,042.0011 (48.62)    1390;171     61,944.3861 (0.04)      14202           1
test_banded_dot_perf[1000-banded_dot]           8,167.0005 (16.33)       68,749.9996 (1.74)        10,694.7895 (16.07)     1,131.4230 (2.89)        11,000.0001 (16.52)        416.9997 (9.93)    6811;7582     93,503.4764 (0.06)      32521           1

test_banded_dot_perf[10_000-dot]            3,379,415.9972 (>1000.0)  3,680,959.0019 (93.19)    3,463,207.0645 (>1000.0)  79,485.8545 (203.30)   3,434,124.9993 (>1000.0)  114,541.9992 (>1000.0)       6;0        288.7497 (0.00)         31           1
test_banded_dot_perf[10_000-banded_dot]        93,582.9994 (187.17)     294,458.0010 (7.45)       100,154.2338 (150.50)   22,660.4163 (57.96)       95,479.0012 (143.36)     2,083.4996 (49.61)       10;27      9,984.6004 (0.01)        248           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

jessegrabowski avatar May 23 '25 10:05 jessegrabowski

That looks much better!

ricardoV94 avatar May 23 '25 10:05 ricardoV94

I agree numba will probably be better across the board. I'd really like this Op to win on the 100x100 case; that's already a pretty big matrix. 1000x1000 and 10,000x10,000 don't really show up in nature too often.

jessegrabowski avatar May 23 '25 10:05 jessegrabowski

100x100 is 1us; you are at the edge of python overhead there. Calling an identity PyTensor function without trust_input costs 300-500ns. Calling np.zeros is like 100-200ns. That means you would basically need to have no python overhead whatsoever.

Edit: those are on my machine, don't know about yours
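Those floors are easy to sanity-check with a quick timeit (the exact numbers are machine-dependent, as noted):

```python
import timeit
import numpy as np

# per-call cost of a tiny allocation: a rough floor for any python-level Op
# that has to build the band-storage array before calling BLAS
n_calls = 100_000
per_call = timeit.timeit(lambda: np.zeros((4, 100)), number=n_calls) / n_calls
print(f"np.zeros((4, 100)): ~{per_call * 1e9:.0f} ns per call")
```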

ricardoV94 avatar May 23 '25 10:05 ricardoV94

This is the best I think we can get out of this in python:

    def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
        kl = self.lower_diags
        ku = self.upper_diags
        if node.outputs[0].dtype == "float64":
            gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float64")
        else:
            gbmv = scipy_linalg.get_blas_funcs("gbmv", dtype="float32")

        ab_size = kl + ku + 1
        a_storage = storage_map[node.inputs[0]]
        b_storage = storage_map[node.inputs[1]]
        out_storage = storage_map[node.outputs[0]]
        out_computed = compute_map[node.outputs[0]] if compute_map is not None else [False]

        def thunk(
            a_storage=a_storage,
            b_storage=b_storage,
            out_storage=out_storage,
            out_computed=out_computed,
            kl=kl,
            ku=ku,
            ab_size=ab_size,
            gbmv=gbmv,
        ):
            A = a_storage[0]
            b = b_storage[0]
            m, n = A.shape

            # pack the diagonals of A into LAPACK band storage
            ab = np.zeros((ab_size, n), dtype=A.dtype, order="C")
            for i, k in enumerate(range(ku, -kl - 1, -1)):
                if k > 0:
                    ab[i, k:] = np.diag(A, k=k)
                else:
                    ab[i, : n + k] = np.diag(A, k=k)

            out_storage[0] = gbmv(m, n, kl, ku, 1, ab, b)
            out_computed[0] = True

        return thunk

ricardoV94 avatar May 23 '25 11:05 ricardoV94

I'm not saying we should do that, but it gives you a lower bound on what to expect from your micro-optimizations

ricardoV94 avatar May 23 '25 11:05 ricardoV94

Here's what the thunk version benchmarks as for me:

------------------------------------------------------------------------------------------------------------------- benchmark: 8 tests ------------------------------------------------------------------------------------------------------------------
Name (time in ns)                                      Min                       Max                      Mean                  StdDev                    Median                    IQR            Outliers             OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_banded_dot_perf[10-dot]                      582.9970 (1.0)          7,208.0002 (1.0)            648.7823 (1.0)          105.4763 (1.0)            625.0011 (1.0)          41.9968 (1.0)       184;252  1,541,349.0560 (1.0)       18434           1
test_banded_dot_perf[10-banded_dot]             2,749.9991 (4.72)        28,665.9997 (3.98)         2,954.8453 (4.55)         350.8606 (3.33)         2,917.0005 (4.67)         42.9973 (1.02)     555;5229    338,427.1940 (0.22)      39868           1

test_banded_dot_perf[100-dot]                   1,042.0008 (1.79)        15,624.9989 (2.17)         1,178.4495 (1.82)         197.8076 (1.88)         1,166.9981 (1.87)         42.0005 (1.00)     512;1917    848,572.6277 (0.55)     100848           1
test_banded_dot_perf[100-banded_dot]            3,166.9988 (5.43)        33,166.9980 (4.60)         3,418.6797 (5.27)         364.1081 (3.45)         3,415.9966 (5.47)         83.0005 (1.98)     826;2615    292,510.5862 (0.19)      65574           1

test_banded_dot_perf[1000-dot]                 13,334.0000 (22.87)       45,625.0018 (6.33)        15,480.3238 (23.86)      1,366.7475 (12.96)       15,957.9977 (25.53)     1,958.0002 (46.62)    1490;223     64,598.1318 (0.04)      20426           1
test_banded_dot_perf[1000-banded_dot]           8,541.9997 (14.65)       50,667.0003 (7.03)        10,089.9543 (15.55)        777.8152 (7.37)        10,416.9994 (16.67)     1,290.9986 (30.74)   11635;128     99,108.4762 (0.06)      38096           1

test_banded_dot_perf[10_000-dot]            3,365,791.9994 (>1000.0)  5,034,374.9972 (698.44)   3,495,052.0250 (>1000.0)  345,179.3641 (>1000.0)  3,410,270.5013 (>1000.0)  47,562.5002 (>1000.0)       2;3        286.1188 (0.00)         40           1
test_banded_dot_perf[10_000-banded_dot]        80,417.0013 (137.94)     454,208.9991 (63.01)      119,363.4743 (183.98)    65,435.1952 (620.38)      91,417.0014 (146.27)   38,540.9949 (917.71)      33;33      8,377.7722 (0.01)        350           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

I'm curious whether it's possible to destroy A and convert it into A_banded in place. If it is possible, it doesn't seem trivial. BLAS doesn't have an overwrite_x option, so b can't be destroyed either.

Frankly, my time would be better spent thinking about how to do this in C at this point.

jessegrabowski avatar May 23 '25 11:05 jessegrabowski

Also we should probably be benchmarking this against sparse_dot -- this all might be a waste of time?

jessegrabowski avatar May 23 '25 11:05 jessegrabowski

Also we should probably be benchmarking this against sparse_dot -- this all might be a waste of time?

Well, SparseDot doesn't work with batched inputs, but I'm curious. Also, I don't think the code is too complex or performing too badly, so I don't agree with your sentiment that we should be thinking of a C impl. A numba one is more interesting...
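As a correctness-level sketch of what such a comparison involves (not a benchmark from this thread), a banded matrix maps naturally onto scipy's DIA sparse format:

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(42)
n, kl, ku = 500, 1, 2
A = np.tril(np.triu(rng.normal(size=(n, n)), -kl), ku)  # banded dense matrix
x = rng.normal(size=n)

A_dia = sp.dia_matrix(A)  # stores only the kl + ku + 1 nonzero diagonals
A_csr = A_dia.tocsr()     # the format a sparse dot would typically use

assert np.allclose(A_dia @ x, A @ x)
assert np.allclose(A_csr @ x, A @ x)
```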

ricardoV94 avatar May 23 '25 11:05 ricardoV94

Por que não os dois?

Seriously though, my feeling is that if we're putting this stuff into a PyMC model, the code has to be ultra-performant. It's going to be called umptillion times: the inner loop of a PDE solver times the MCMC loop.

I'll work on the numba dispatch next at any rate

jessegrabowski avatar May 23 '25 11:05 jessegrabowski

By that argument you can't really add any specialized Op that doesn't have a C implementation (unless it's replacing an Op that also doesn't have a C implementation).

Ignoring the general user, you can have code that decides whether to use this Op based on the size (or a rewrite). Also, how are you sampling / getting A? Can you avoid the boxing/unboxing of the diagonals?
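A size-based choice could be as simple as the following sketch (`banded_dot`, the threshold value, and the function name are all placeholders; the real crossover should come from benchmarks):

```python
import numpy as np

def dot_with_size_cutoff(A, x, banded_dot, threshold=200):
    """Fall back to plain dense dot below a size threshold,
    where the banded Op's call overhead dominates."""
    if A.shape[0] < threshold:
        return A @ x
    return banded_dot(A, x)

# usage with a stand-in for the banded Op (here just dense dot)
fake_banded = lambda A, x: A @ x
small, v_small = np.eye(10), np.ones(10)
big, v_big = np.eye(500), np.ones(500)
assert np.allclose(dot_with_size_cutoff(small, v_small, fake_banded), small @ v_small)
assert np.allclose(dot_with_size_cutoff(big, v_big, fake_banded), big @ v_big)
```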

ricardoV94 avatar May 23 '25 11:05 ricardoV94

Well, the point is the specialization isn't adding anything over good ol' pt.dot (yet!), except for really huge matrices.

jessegrabowski avatar May 23 '25 11:05 jessegrabowski

Numba benchmarks. I think the numba dispatch could be optimized further.

--------------------------------------------------------------------------------------------------------------------------- benchmark: 16 tests ---------------------------------------------------------------------------------------------------------------------------
Name (time in ns)                                                Min                        Max                       Mean                    StdDev                     Median                       IQR            Outliers             OPS            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_banded_dot_perf[10-FAST_RUN-dot]                       582.9934 (1.0)           5,582.9951 (1.0)             669.4463 (1.0)             72.6255 (1.0)             666.0048 (1.0)             41.9968 (2.00)    2773;2451  1,493,771.7737 (1.0)       38524           1
test_banded_dot_perf[10-NUMBA-dot]                        1,124.9940 (1.93)         44,874.9997 (8.04)          1,282.0732 (1.92)           234.5687 (3.23)          1,250.0022 (1.88)            42.0041 (2.00)    2930;7076    779,986.6647 (0.52)     119404           1

test_banded_dot_perf[10-FAST_RUN-banded_dot]              3,332.9998 (5.72)         39,458.0056 (7.07)          3,538.1992 (5.29)           325.6973 (4.48)          3,500.0048 (5.26)           123.9969 (5.91)    1430;1601    282,629.6514 (0.19)      84503           1
test_banded_dot_perf[10-NUMBA-banded_dot]                 1,250.0022 (2.14)      1,027,250.0003 (184.00)        1,420.5034 (2.12)         2,789.3330 (38.41)         1,415.9959 (2.13)            42.0041 (2.00)     180;6730    703,975.7880 (0.47)     136352           1


test_banded_dot_perf[100-FAST_RUN-dot]                    1,051.9998 (1.80)         10,291.7493 (1.84)          1,110.6063 (1.66)            81.3833 (1.12)          1,104.2503 (1.66)            20.9984 (1.0)      395;4285    900,409.0542 (0.60)     191976           4
test_banded_dot_perf[100-NUMBA-dot]                       2,957.9969 (5.07)         57,625.0022 (10.32)         3,456.9208 (5.16)           389.0581 (5.36)          3,417.0007 (5.13)           125.0010 (5.95)     553;3677    289,274.7781 (0.19)     105731           1

test_banded_dot_perf[100-FAST_RUN-banded_dot]             3,666.9953 (6.29)         41,667.0009 (7.46)          3,981.1827 (5.95)           360.0852 (4.96)          3,957.9973 (5.94)           165.9937 (7.91)    2100;2312    251,181.6399 (0.17)     121833           1
test_banded_dot_perf[100-NUMBA-banded_dot]                1,666.9946 (2.86)         28,708.0038 (5.14)          1,867.6628 (2.79)           241.6104 (3.33)          1,833.9997 (2.75)            42.0041 (2.00)   1994;13553    535,428.5652 (0.36)     130430           1


test_banded_dot_perf[1000-FAST_RUN-dot]                  13,874.9965 (23.80)        50,374.9980 (9.02)         15,947.8134 (23.82)        1,284.7185 (17.69)        16,540.9947 (24.84)        2,000.9975 (95.29)    5008;216     62,704.5209 (0.04)      40540           1
test_banded_dot_perf[1000-NUMBA-dot]                    366,542.0008 (628.72)      944,250.0013 (169.13)      394,639.3455 (589.50)      47,299.9619 (651.29)      375,083.4985 (563.18)      29,312.4976 (>1000.0)   238;221      2,533.9592 (0.00)       2104           1

test_banded_dot_perf[1000-FAST_RUN-banded_dot]            8,625.0002 (14.79)       136,833.9945 (24.51)        10,687.2292 (15.96)        1,682.1809 (23.16)        11,165.9974 (16.77)        1,873.9956 (89.24)     363;264     93,569.6219 (0.06)      62337           1
test_banded_dot_perf[1000-NUMBA-banded_dot]               6,292.0008 (10.79)        37,042.0021 (6.63)          7,550.0573 (11.28)          633.0026 (8.72)          7,709.0008 (11.57)          249.9946 (11.91)  17365;17601    132,449.3267 (0.09)      84804           1


test_banded_dot_perf[10_000-FAST_RUN-dot]             3,360,708.0040 (>1000.0)   4,747,791.0048 (850.40)    3,401,666.0148 (>1000.0)     99,592.4120 (>1000.0)   3,388,875.0004 (>1000.0)     21,333.5006 (>1000.0)      3;14        293.9736 (0.00)        200           1
test_banded_dot_perf[10_000-NUMBA-dot]               75,194,832.9998 (>1000.0)  79,603,917.0018 (>1000.0)  77,050,570.3840 (>1000.0)  1,533,475.7321 (>1000.0)  77,686,790.9954 (>1000.0)  2,637,947.7476 (>1000.0)       6;0         12.9785 (0.00)         13           1

test_banded_dot_perf[10_000-FAST_RUN-banded_dot]         87,583.0010 (150.23)      547,832.9986 (98.13)       151,780.8135 (226.73)      56,039.1835 (771.62)      152,041.0005 (228.29)      47,916.0062 (>1000.0)   660;123      6,588.4480 (0.00)       2659           1
test_banded_dot_perf[10_000-NUMBA-banded_dot]            81,334.0012 (139.51)      461,791.9985 (82.71)       102,538.0625 (153.17)      32,585.4856 (448.68)       89,959.0013 (135.07)       4,094.2505 (194.98)   896;1290      9,752.4761 (0.01)       7569           1

---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

jessegrabowski avatar May 24 '25 08:05 jessegrabowski

Numba's regular dot is doing something crappy. If you call gemv yourself, you should be able to make it as good as the C dot (which I guess is being rewritten into Gemv)?

Otherwise, this still shows the regular C dot winning over the numba banded dot at n=100.

You may want to add n=256 or so to see where the threshold is; there's a big jump from iffy at n=100 to clearly better at n=1000.

Or make a plot instead?

ricardoV94 avatar May 24 '25 08:05 ricardoV94

Look how shitty numba dot is, moving on...

[benchmark plot]

Looks like numba is a constant factor faster, and this specialized dot is better after about size 200x200.

[benchmark plot]

jessegrabowski avatar May 24 '25 09:05 jessegrabowski

That's much more palatable.

The difference between the numba and python gbmv is also what you should expect to see if you implemented gbmv in C, so you don't have to wonder.

ricardoV94 avatar May 24 '25 09:05 ricardoV94

[benchmark plot]

🤔

jessegrabowski avatar May 24 '25 09:05 jessegrabowski

The problem in the timings was some copies being done in both python and numba mode. Here are the updated timings. They're essentially the same, except at the low end, where getting rid of the python overhead gives numba a small, consistent speed bump.

[benchmark plots]

jessegrabowski avatar May 24 '25 11:05 jessegrabowski

I'd like to call this one done for now, although there are three major things that are left to do:

  1. Enable GEMV rewrites in NUMBA and re-use that machinery to allow all arguments to the numba xgemv function. Right now I'm forcing alpha=1, beta=0.
  2. Split off the code that converts a dense banded matrix into banded form into a separate Op. Then we can add rewrites, for example to lift that conversion out of a scan. More importantly, we can:
  3. Introduce a rewrite that converts GEMV(BandedMatrix(A), x, ...) into BandedGEMV(BandedMatrix(A), x, ...). The existing BandedDot can become BandedGEMV, and we can use all its arguments.

I want to merge this and then do these three things, because I want to do #1418 first and put the resulting function into the new _BLAS.py file in this PR. Then enable the relevant rewrites and revisit this code.

I also need to think about how to handle splitting out the BandedMatrix Op, because the banded form destroys information about how many rows the input matrix has (and gbmv needs to know this).
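The information loss is concrete: band storage has shape (kl + ku + 1, n) regardless of m, so matrices with different row counts pack to identically shaped arrays. A small illustration (`pack_band` is a hypothetical general m x n packer, not code from this PR):

```python
import numpy as np

def pack_band(A, kl, ku):
    """General m x n LAPACK band storage: ab[ku + i - j, j] = A[i, j]."""
    m, n = A.shape
    ab = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for j in range(n):
        for i in range(max(0, j - ku), min(m, j + kl + 1)):
            ab[ku + i - j, j] = A[i, j]
    return ab

kl, ku = 1, 1
A4 = np.ones((4, 4))  # only the band is packed, so ones() suffices for the demo
A6 = np.ones((6, 4))
# both pack to (kl + ku + 1, 4): m = 4 vs m = 6 is unrecoverable from ab alone
assert pack_band(A4, kl, ku).shape == pack_band(A6, kl, ku).shape == (3, 4)
```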

jessegrabowski avatar May 24 '25 13:05 jessegrabowski

Codecov Report

Attention: Patch coverage is 71.90083% with 34 lines in your changes missing coverage. Please review.

Project coverage is 82.09%. Comparing base (d10f245) to head (976422f).

Files with missing lines                             Patch %   Lines
pytensor/link/numba/dispatch/linalg/dot/banded.py    46.00%    27 Missing ⚠️
pytensor/tensor/slinalg.py                           90.00%    2 Missing and 2 partials ⚠️
pytensor/link/numba/dispatch/slinalg.py              75.00%    2 Missing and 1 partial ⚠️

❌ Your patch check has failed because the patch coverage (71.90%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1416      +/-   ##
==========================================
- Coverage   82.12%   82.09%   -0.03%     
==========================================
  Files         211      213       +2     
  Lines       49757    49878     +121     
  Branches     8819     8826       +7     
==========================================
+ Hits        40862    40949      +87     
- Misses       6715     6746      +31     
- Partials     2180     2183       +3     
Files with missing lines                             Coverage Δ
pytensor/link/numba/dispatch/basic.py                79.08% <ø> (ø)
pytensor/link/numba/dispatch/linalg/_BLAS.py         100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/slinalg.py              70.10% <75.00%> (+0.34%) ↑
pytensor/tensor/slinalg.py                           93.00% <90.00%> (-0.19%) ↓
pytensor/link/numba/dispatch/linalg/dot/banded.py    46.00% <46.00%> (ø)

codecov[bot] avatar May 24 '25 13:05 codecov[bot]

@ricardoV94 since #1418 got resolved without adding the GEMV rewrite to numba, how should I handle expanding this Op to include rank-1 updates?

jessegrabowski avatar May 28 '25 07:05 jessegrabowski

We may still want to link directly to BLAS for the full update; I'm not sure numba does it beyond dispatching the matrix/vector dot part.

ricardoV94 avatar May 28 '25 07:05 ricardoV94

I would start by benchmarking directly with numba to see whether we get a speedup from calling the fused gemv routine directly, or whether numba already does it (for the regular gemv, that is; it for sure doesn't do it for gbmv).

ricardoV94 avatar May 28 '25 07:05 ricardoV94

I just pushed a major refactor to this PR, which:

  1. Renames BandedDot to BandedGEMV (which is what it actually is; the underlying routine is called GBMV, but I thought BandedGEMV was clearer)
  2. Adds support for all GBMV arguments (A, x, y, alpha, beta) in BandedGEMV
  3. Adjusts the numba overload accordingly
  4. Adds a numba overload for GEMV itself. Note that this will never be used, because we don't include BlasOpt in the numba rewrites.

Regarding point (4), here are the benchmarks using the numba GEMV overload vs. what we currently get with mode="NUMBA":

------------------------------------------------------------------------------------------------ benchmark: 6 tests ------------------------------------------------------------------------------------------------
Name (time in us)                                  Min                 Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_numba_gemv_benchmark[numba-10]             6.7500 (1.0)      109.3330 (1.32)      7.9952 (1.0)      1.6793 (1.09)      7.8750 (1.0)      0.2910 (1.40)     878;1881      125.0747 (1.0)       29888           1
test_numba_gemv_benchmark[numba+blas-10]        7.5000 (1.11)      82.7080 (1.0)       8.0753 (1.01)     1.9040 (1.24)      7.8750 (1.00)     0.2080 (1.0)       252;986      123.8337 (0.99)      14185           1

test_numba_gemv_benchmark[numba-100]            8.0410 (1.19)     129.5830 (1.57)     10.3897 (1.30)     1.7014 (1.11)     10.5830 (1.34)     1.2920 (6.21)      157;122       96.2488 (0.77)      29963           1
test_numba_gemv_benchmark[numba+blas-100]       7.9170 (1.17)      89.1250 (1.08)     10.1977 (1.28)     1.5343 (1.0)      10.4170 (1.32)     1.2080 (5.81)      218;165       98.0610 (0.78)      32129           1


test_numba_gemv_benchmark[numba-1000]          22.2920 (3.30)     708.7500 (8.57)     25.6014 (3.20)     6.1947 (4.04)     24.9170 (3.16)     1.6670 (8.01)     229;1198       39.0604 (0.31)      18824           1
test_numba_gemv_benchmark[numba+blas-1000]     21.3330 (3.16)     186.1250 (2.25)     24.8268 (3.11)     4.2993 (2.80)     24.0830 (3.06)     1.8750 (9.01)      448;739       40.2790 (0.32)      17992           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

It's about the same, or maybe slightly better, but at the cost that we can no longer cache the compiled function, due to the function pointer.

Also note that the test is very sensitive to the detection of the alpha parameter. I had to write:

alpha * (A @ x) + beta * y

in order for the GEMV rewrite to correctly find alpha. If it fails to find alpha, mode="NUMBA" significantly outperforms the GEMV call.

jessegrabowski avatar Jun 26 '25 04:06 jessegrabowski