
Rewrite inverse for triangular matrix

Open asifzubair opened this issue 2 months ago β€’ 15 comments

Description

We add a rewrite for matrix inversion when the matrix is known to be triangular.

We check three conditions:

  • Whether the variable is tagged as upper/lower triangular
  • Whether the producing Op is Tri
  • Whether the producing Op is Cholesky
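The three symbolic conditions above can be illustrated at the value level; a minimal sketch of the triangularity test (the helper name is illustrative, not part of pytensor):

```python
import numpy as np

def is_lower_triangular(a):
    """Value-level check that a square matrix is lower triangular.

    This mirrors numerically what the graph rewrite checks symbolically
    (via tags and the producing Op); the helper name is hypothetical.
    """
    return bool(np.allclose(a, np.tril(a)))

rng = np.random.default_rng(0)
A = np.tril(rng.random((4, 4)))   # lower triangular
B = rng.random((4, 4))            # dense
```

In the rewrite itself the check is structural (tags and Ops) rather than numeric, since the values aren't known at rewrite time.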

Related Issue

  • [ ] Closes #
  • [x] Related to #573

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

πŸ“š Documentation preview πŸ“š: https://pytensor--1612.org.readthedocs.build/en/1612/

asifzubair avatar Sep 16 '25 02:09 asifzubair

Hi @jessegrabowski , I haven't added a test yet, but if this approach is valid, I can add one. Please let me know. Thank you! πŸ™

CC: @theorashid , @ColtAllen

asifzubair avatar Sep 16 '25 02:09 asifzubair

Can you post some timings showing that this is advantageous? inv(A) and solve(A, Eye) are basically the same thing. I recognize that there is some advantage because you're using the specialized solve, but I'd like to see what it gives us.
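The equivalence claimed here is easy to verify numerically (sizes and seed are illustrative):

```python
import numpy as np

# inv(A) and solve(A, I) are mathematically identical: solve(A, I) returns
# the X satisfying A @ X = I, i.e. A^{-1}. The differences are purely in
# which LAPACK path each call takes.
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 5)) + 5.0 * np.eye(5)  # well-conditioned
X_inv = np.linalg.inv(A)
X_solve = np.linalg.solve(A, np.eye(5))
np.testing.assert_allclose(X_inv, X_solve, atol=1e-10)
```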

Can you also time solve_triangular(A, Eye) against directly using DTRTRI? That's the specialized LAPACK routine for the inverse of a triangular matrix (as opposed to a solve).
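As a sanity check before timing, the two routines do agree numerically; a quick sketch (sizes illustrative, SciPy names only):

```python
import numpy as np
from scipy.linalg import solve_triangular
from scipy.linalg.lapack import dtrtri

rng = np.random.default_rng(0)
A = np.tril(rng.random((6, 6))) + np.eye(6)  # well-conditioned lower triangular

inv_direct, info = dtrtri(A, lower=1)        # one LAPACK call, no RHS needed
assert info == 0
inv_solve = solve_triangular(A, np.eye(6), lower=True)  # triangular solve vs I
np.testing.assert_allclose(inv_direct, inv_solve, atol=1e-10)
```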

jessegrabowski avatar Sep 17 '25 11:09 jessegrabowski

Thank you, @jessegrabowski . I'd be interested in seeing those numbers as well. Let me do that study and report back. Thanks, again πŸ™

asifzubair avatar Sep 19 '25 14:09 asifzubair

Hi @jessegrabowski , I did the study that was suggested. I used the underlying perform methods to benchmark inv(), solve(), solve_triangular(), and dtrtri().

| Size | inv() | solve() | solve_tri() | dtrtri() | inv / solve_tri | solve_tri / dtrtri |
|-----:|------:|--------:|------------:|---------:|----------------:|-------------------:|
| 50   | 0.01832s | 0.00713s | 0.01206s | 0.00096s | 1.52x | 12.59x |
| 100  | 0.02006s | 0.01739s | 0.00905s | 0.00358s | 2.22x | 2.53x |
| 250  | 0.09416s | 0.10159s | 0.04149s | 0.06596s | 2.27x | 0.63x |
| 500  | 0.43439s | 0.57095s | 0.12388s | 0.07183s | 3.51x | 1.72x |
| 750  | 1.07380s | 1.16674s | 0.46847s | 0.18619s | 2.29x | 2.52x |
| 1000 | 2.64248s | 2.70513s | 1.07507s | 0.40498s | 2.46x | 2.65x |
| 2000 | 22.21642s | 27.74535s | 8.88119s | 2.54412s | 2.50x | 3.49x |
*(plot of the timings above against matrix size)*

On average, solve_triangular() gives a ~2x improvement over inv(), and dtrtri() gives a further ~2x improvement over solve_triangular(). Please let me know your thoughts. Thank you 🙏

Click to view benchmarking code
import timeit
import numpy as np
import scipy.linalg
from scipy.linalg.lapack import dtrtri

matrix_sizes = [50, 100, 250, 500, 750, 1000, 2000]
n_repeats = 100

results = {}
for size in matrix_sizes:
    print(f"Running for size {size}x{size}...")

    # Well-conditioned lower-triangular test matrix
    A_tril = np.tril(np.random.rand(size, size))
    A_tril[np.diag_indices(size)] += 1.0
    I = np.eye(size)

    t_inv = timeit.timeit(lambda: np.linalg.inv(A_tril), number=n_repeats)
    t_solve = timeit.timeit(lambda: np.linalg.solve(A_tril, I), number=n_repeats)
    t_solve_tri = timeit.timeit(
        lambda: scipy.linalg.solve_triangular(A_tril, I, lower=True),
        number=n_repeats,
    )
    # LAPACK expects Fortran order; convert outside the timed region
    A_fortran = np.asfortranarray(A_tril)
    t_dtrtri = timeit.timeit(
        lambda: dtrtri(A_fortran, lower=1),
        number=n_repeats,
    )

    results[size] = {
        "inv": t_inv,
        "solve": t_solve,
        "solve_triangular": t_solve_tri,
        "dtrtri": t_dtrtri,
        "inv_div_solve_tri": t_inv / t_solve_tri if t_solve_tri > 0 else 0,
        "solve_tri_div_dtrtri": t_solve_tri / t_dtrtri if t_dtrtri > 0 else 0,
    }

for size, r in results.items():
    print(
        f"{size}: inv={r['inv']:.5f}s solve={r['solve']:.5f}s "
        f"solve_tri={r['solve_triangular']:.5f}s dtrtri={r['dtrtri']:.5f}s"
    )
asifzubair avatar Sep 22 '25 00:09 asifzubair

That's really awesome! Thanks for doing this study.

Given these results, my suggestion would be to make a TriangularInv Op that subclasses from Inv. In the perform method, use trtri directly (use get_lapack_funcs to look it up rather than importing it directly, so you get the right routine based on the dtype of the inputs). We don't have to expose it to users, but you can use it in rewrites.

We can also then add a rewrite that changes TriangularInv(x) @ b to solve_triangular(x, b). But that's another PR for another day :)
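A minimal sketch of the dtype-dispatched lookup being suggested, using scipy's get_lapack_funcs (the function name and error handling here are assumptions, not pytensor's actual implementation):

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

def triangular_inv_perform(a, lower=True):
    # get_lapack_funcs picks strtri/dtrtri/ctrtri/ztrtri based on a's dtype
    (trtri,) = get_lapack_funcs(("trtri",), (a,))
    inv_a, info = trtri(a, lower=int(lower))
    if info != 0:
        raise np.linalg.LinAlgError(f"trtri failed with info={info}")
    return inv_a

rng = np.random.default_rng(0)
A = np.tril(rng.random((4, 4))) + np.eye(4)
```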

jessegrabowski avatar Sep 22 '25 00:09 jessegrabowski

Codecov Report

:x: Patch coverage is 84.21053% with 15 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 81.68%. Comparing base (f83c05b) to head (07c48f3).
:warning: Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/slinalg.py 79.48% 7 Missing and 1 partial :warning:
pytensor/tensor/rewriting/linalg.py 87.27% 0 Missing and 7 partials :warning:

:x: Your patch check has failed because the patch coverage (84.21%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1612   +/-   ##
=======================================
  Coverage   81.68%   81.68%           
=======================================
  Files         244      244           
  Lines       53549    53642   +93     
  Branches     9433     9459   +26     
=======================================
+ Hits        43741    43819   +78     
- Misses       7328     7335    +7     
- Partials     2480     2488    +8     
Files with missing lines Coverage Ξ”
pytensor/tensor/nlinalg.py 94.45% <100.00%> (ΓΈ)
pytensor/tensor/rewriting/linalg.py 90.62% <87.27%> (-0.47%) :arrow_down:
pytensor/tensor/slinalg.py 90.89% <79.48%> (-0.51%) :arrow_down:

codecov[bot] avatar Sep 22 '25 15:09 codecov[bot]

Hi @jessegrabowski , could you please review when you have the time? The checks are similar to this PR's, except that, of course, we use the LAPACK solver. I'm also happy to include other operations like det and eig in this PR, if you recommend it.

Also, if helpful, I can start a separate issue to track TriangularInv(x) @ b.

Please let me know. Thanks!

asifzubair avatar Sep 29 '25 04:09 asifzubair

Also, if helpful, I can start a separate issue to track TriangularInv(x) @ b.

Sure, feel free to open an issue. If you prefer to wait for this to be merged, that's fine too.

jessegrabowski avatar Oct 01 '25 02:10 jessegrabowski

Thanks for the ping! This is looking really amazing, and it's getting very close πŸ₯³

Thank you, @jessegrabowski ! πŸ™ I took a stab at the comments. Notably, I added triu and tril checks. However, when I tried adding the pt.tri check, it seems to get const folded so doesn't trigger the re-write. Happy to be corrected.

Please let me know your thoughts. Thanks, again! πŸ™

asifzubair avatar Oct 05 '25 20:10 asifzubair

So you know, you can run mypy locally from inside the pytensor project folder with python scripts/run_mypy.py --verbose. You then get a huge, horrible read-out that you have to sift through for the list of recent failures. I strongly recommend clearing the terminal before you run it.

jessegrabowski avatar Oct 05 '25 23:10 jessegrabowski

For the failed float32 test, make sure you set the atol and rtol much more relaxed when config.floatX is float32. Check the other tests to see what we do. We need to think of a better way to test linalg routines at half-precision...
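A sketch of the dtype-dependent tolerance pattern being described (the thresholds and helper name are illustrative, not pytensor's actual test values):

```python
import numpy as np

def assert_allclose_by_dtype(actual, desired, floatX="float64"):
    # Much looser tolerances at single precision; thresholds here are
    # illustrative assumptions, not taken from pytensor's test suite.
    rtol = 1e-7 if floatX == "float64" else 1e-4
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=rtol)
```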

jessegrabowski avatar Oct 05 '25 23:10 jessegrabowski

Thank you, @jessegrabowski . Yes, sorry, I belatedly realized I'd missed the mypy check.

I ran it locally and it seems that I'm fighting this error:

[pytensor/tensor/slinalg.py]
pytensor/tensor/slinalg.py:958: error [assignment]: Incompatible types in assignment (expression has type "tuple[str, str, str]", base class "MatrixInverse" defined the type as "tuple[()]")

Should I redefine the __props__ type in TriangularInv?
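One way to resolve this kind of mypy error is to widen the annotation on the base class attribute so that subclass overrides with non-empty tuples type-check; a minimal standalone analogue (class and property names here are illustrative, not pytensor's actual ones):

```python
from typing import ClassVar

class MatrixInverseLike:
    # Annotating explicitly as tuple[str, ...] (rather than letting mypy
    # infer tuple[()] from the empty literal) permits subclass overrides.
    __props__: ClassVar[tuple[str, ...]] = ()

class TriangularInvLike(MatrixInverseLike):
    __props__: ClassVar[tuple[str, ...]] = ("lower", "unit_diagonal", "on_error")
```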

For the test failure, I had tried following the idiom seen elsewhere:

    np.testing.assert_allclose(
        f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )

But perhaps I need to loosen the requirement (to 1e-4) for my test. Would that be okay?

I also realized some of my tests are in the wrong location: the TriangularInv Op class tests should be in tests/tensor/test_slinalg.py, not in rewriting. Sorry about that as well; I'll move them.

Also, a heads up: I have to do some convoluted checks for rewriting triu/tril. I've tested that these do work, in that they trigger the rewrite and return a correct answer; however, I'm not sure whether there is an easier way to do the check. Please let me know if there is. Thank you. 🙏

asifzubair avatar Oct 05 '25 23:10 asifzubair

Double check the is_triangular check for the LU/QR cases, then I think this is done! Really great work!

Thank you, @jessegrabowski . Really appreciate the help in getting this over the line.

I tried improving test coverage, but Codecov was complaining about a few paths. Do I need to get the test coverage to 100%, or can we make a call there? I'm also wondering if checking the tag, as I do here, is necessary. Could you please let me know? Thank you! 🙏

asifzubair avatar Oct 28 '25 02:10 asifzubair

Do I need to get the test coverage to 100%

No

ricardoV94 avatar Oct 28 '25 15:10 ricardoV94

I'm also wondering if checking tag, as I do here, is necessary

It's not necessary, but some Ops do set that flag, so we might as well check for it.

jessegrabowski avatar Oct 29 '25 01:10 jessegrabowski