pytensor
Rewrite inverse for triangular matrix
Description
We add a rewrite for matrix inversion when the matrix is triangular.
We check three conditions:
- If the variable carries a tag marking it as upper/lower triangular
- If the Op is Tri
- If the Op is Cholesky
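The rewrite is worthwhile because the inverse of a triangular matrix is itself triangular, so a triangular-aware solve can replace the generic inversion. A minimal NumPy sketch of that fact (illustrative only, not the PR's code):

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
n = 5
# Well-conditioned lower-triangular matrix
A = np.tril(rng.random((n, n))) + n * np.eye(n)

# Generic inverse vs. triangular-aware solve against the identity
inv_generic = np.linalg.inv(A)
inv_tri = solve_triangular(A, np.eye(n), lower=True)

assert np.allclose(inv_generic, inv_tri)
# The inverse of a lower-triangular matrix is itself lower-triangular
assert np.allclose(inv_tri, np.tril(inv_tri))
```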
Related Issue
- [ ] Closes #
- [x] Related to #573
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [x] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1612.org.readthedocs.build/en/1612/
Hi @jessegrabowski, I haven't added a test yet, but if this approach is valid, I can add one. Please let me know. Thank you!
CC: @theorashid , @ColtAllen
Can you post some timings showing that this is advantageous? inv(A) and solve(A, Eye) are basically the same thing. I recognize that there is some advantage because you're using the specialized solve, but I'd like to see what it gives us.
Can you also time solve_triangular(A, Eye) against directly using DTRTRI? That's the specialized LAPACK routine for the inverse of a triangular matrix (as opposed to a solve).
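For reference, `dtrtri` is reachable through `scipy.linalg.lapack`; a small sketch (timings aside) checking that it agrees with `solve_triangular` against the identity:

```python
import numpy as np
from scipy.linalg import solve_triangular
from scipy.linalg.lapack import dtrtri

rng = np.random.default_rng(0)
n = 6
A = np.tril(rng.random((n, n))) + n * np.eye(n)

# DTRTRI inverts the triangular matrix directly, with no solve
inv_lapack, info = dtrtri(A, lower=1)
assert info == 0  # info == 0 signals success

inv_solve = solve_triangular(A, np.eye(n), lower=True)
assert np.allclose(inv_lapack, inv_solve)
```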
Thank you, @jessegrabowski. I'd be interested in seeing those numbers as well. Let me do that study and report back. Thanks again!
Hi @jessegrabowski, I did the study that was suggested. I used the underlying perform methods to benchmark inv(), solve(), and solve_triangular(), plus a direct dtrtri() call.
| Size | inv() | solve() | solve_tri() | dtrtri() | (inv/solve_tri) | (solve_tri/dtrtri) |
|---|---|---|---|---|---|---|
| 50 | 0.01832s | 0.00713s | 0.01206s | 0.00096s | 1.52x | 12.59x |
| 100 | 0.02006s | 0.01739s | 0.00905s | 0.00358s | 2.22x | 2.53x |
| 250 | 0.09416s | 0.10159s | 0.04149s | 0.06596s | 2.27x | 0.63x |
| 500 | 0.43439s | 0.57095s | 0.12388s | 0.07183s | 3.51x | 1.72x |
| 750 | 1.07380s | 1.16674s | 0.46847s | 0.18619s | 2.29x | 2.52x |
| 1000 | 2.64248s | 2.70513s | 1.07507s | 0.40498s | 2.46x | 2.65x |
| 2000 | 22.21642s | 27.74535s | 8.88119s | 2.54412s | 2.50x | 3.49x |
On average, solve_triangular() gives a ~2.5x improvement over inv(), and dtrtri() a further ~2.5x improvement over solve_triangular() (the size-250 row being an outlier). Please let me know your thoughts. Thank you!
Click to view benchmarking code
```python
import timeit

import numpy as np
import scipy.linalg
from scipy.linalg.lapack import dtrtri

matrix_sizes = [50, 100, 250, 500, 750, 1000, 2000]
n_repeats = 100
results = {}

for size in matrix_sizes:
    print(f"Running for size {size}x{size}...")
    # Lower-triangular test matrix, diagonal boosted for conditioning
    A_tril = np.tril(np.random.rand(size, size))
    A_tril[np.diag_indices(size)] += 1.0
    I = np.eye(size)

    t_inv = timeit.timeit(lambda: np.linalg.inv(A_tril), number=n_repeats)
    t_solve = timeit.timeit(lambda: np.linalg.solve(A_tril, I), number=n_repeats)
    t_solve_tri = timeit.timeit(
        lambda: scipy.linalg.solve_triangular(A_tril, I, lower=True),
        number=n_repeats,
    )

    # LAPACK expects Fortran-ordered input
    A_fortran = np.asfortranarray(A_tril)
    t_dtrtri = timeit.timeit(
        lambda: dtrtri(A_fortran, lower=1),
        number=n_repeats,
    )

    results[size] = {
        "inv": t_inv,
        "solve": t_solve,
        "solve_triangular": t_solve_tri,
        "dtrtri": t_dtrtri,
        "inv_div_solve_tri": t_inv / t_solve_tri if t_solve_tri > 0 else 0,
        "solve_tri_div_dtrtri": t_solve_tri / t_dtrtri if t_dtrtri > 0 else 0,
    }
```
That's really awesome! Thanks for doing this study.
Given these results, my suggestion would be to make a TriangularInv Op that subclasses from Inv. In the perform method, use trtri directly (use get_lapack_function to get it over directly importing it so you get the right one based on the dtype of the inputs). We don't have to expose it to users, but you can use it in rewrites.
We can also then add a rewrite that changes TriangularInv(x) @ b to solve_triangular(x, b). But that's another PR for another day :)
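The follow-up rewrite rests on a simple identity, sketched here in NumPy (illustrative only, not PyTensor code): multiplying by the inverse of a triangular matrix is the same as a triangular solve, which avoids ever forming the inverse.

```python
import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
n, k = 5, 3
A = np.tril(rng.random((n, n))) + n * np.eye(n)
b = rng.random((n, k))

# inv(A) @ b is mathematically solve_triangular(A, b)
via_inverse = np.linalg.inv(A) @ b
via_solve = solve_triangular(A, b, lower=True)
assert np.allclose(via_inverse, via_solve)
```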
Codecov Report
:x: Patch coverage is 84.21053% with 15 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 81.68%. Comparing base (f83c05b) to head (07c48f3).
:warning: Report is 1 commit behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/tensor/slinalg.py | 79.48% | 7 Missing and 1 partial :warning: |
| pytensor/tensor/rewriting/linalg.py | 87.27% | 0 Missing and 7 partials :warning: |
:x: Your patch check has failed because the patch coverage (84.21%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
```diff
@@            Coverage Diff            @@
##             main    #1612     +/-   ##
=========================================
  Coverage   81.68%   81.68%
=========================================
  Files         244      244
  Lines       53549    53642    +93
  Branches     9433     9459    +26
=========================================
+ Hits        43741    43819    +78
- Misses       7328     7335     +7
- Partials     2480     2488     +8
```
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/tensor/nlinalg.py | 94.45% <100.00%> (ø) | |
| pytensor/tensor/rewriting/linalg.py | 90.62% <87.27%> (-0.47%) | :arrow_down: |
| pytensor/tensor/slinalg.py | 90.89% <79.48%> (-0.51%) | :arrow_down: |
Hi @jessegrabowski, could you please review when you have the time? I do checks similar to this PR, except that here, of course, we use the LAPACK solver. I'm also happy to include other operations like det and eig in this PR if you recommend it.
Also, if helpful, I can start a separate issue to track TriangularInv(x) @ b.
Please let me know. Thanks!
> Also, if helpful, I can start a separate issue to track TriangularInv(x) @ b.
Sure, feel free to open an issue. If you prefer to wait for this to be merged, that's fine too.
Thanks for the ping! This is looking really amazing, and it's getting very close 🥳
Thank you, @jessegrabowski! I took a stab at the comments. Notably, I added triu and tril checks. However, when I tried adding the pt.tri check, it seems to get constant-folded, so it doesn't trigger the rewrite. Happy to be corrected.
Please let me know your thoughts. Thanks again!
So you know, you can run mypy locally from inside the pytensor project folder with `python scripts/run_mypy.py --verbose`. You then get a huge, horrible read-out that you have to sift through for the list of recent failures. I strongly recommend clearing the terminal before you run it.
For the failed float32 test, make sure you set the atol and rtol much more relaxed when config.floatX is float32. Check the other tests to see what we do. We need to think of a better way to test linalg routines at half-precision...
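Such a dtype-dependent tolerance might look like the following sketch (the helper name and exact values are illustrative; mirror whatever the neighboring pytensor tests use):

```python
import numpy as np

def allclose_for_dtype(actual, desired, floatX):
    # Hypothetical helper: loosen tolerances at single precision.
    rtol = 1e-7 if floatX == "float64" else 1e-4
    atol = 1e-8 if floatX == "float64" else 1e-4
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)

# A small float32-scale error passes under the relaxed tolerances
allclose_for_dtype(1.0 + 5e-5, 1.0, "float32")
```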
Thank you, @jessegrabowski. Yes, sorry, I realized belatedly about the mypy check.
I ran it locally, and it seems I'm fighting this error:
```
pytensor/tensor/slinalg.py:958: error [assignment]: Incompatible types in assignment (expression has type "tuple[str, str, str]", base class "MatrixInverse" defined the type as "tuple[()]")
```
Should I redefine the `__props__` type in `TriangularInv`?
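For context, here is a minimal reproduction of this mypy pattern outside pytensor (class names hypothetical). Because the base assigns `()` without an annotation, mypy infers the narrow type `tuple[()]`, which a subclass cannot widen; annotating the attribute as `tuple[str, ...]` avoids the error:

```python
# Base with an unannotated empty tuple: mypy infers tuple[()]
class Base:
    __props__ = ()

class SubBad(Base):
    __props__ = ("lower",)  # mypy: incompatible assignment with tuple[()]

# Annotating the attribute with a wider type on the base avoids the error
class BaseAnnotated:
    __props__: tuple[str, ...] = ()

class SubOk(BaseAnnotated):
    __props__ = ("lower",)

assert SubOk.__props__ == ("lower",)
```

At runtime both versions behave identically; only the static type differs.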
For the test failure, I had tried following the idiom seen elsewhere:
```python
np.testing.assert_allclose(
    f(a_val, b_val), c_val, rtol=1e-7 if config.floatX == "float64" else 1e-5
)
```
But perhaps I need to loosen the requirement (1e-4) for my test. Would that be okay?
I also realized some of my tests are in the wrong location: the TriangularInv Op class tests should be in tests/tensor/test_slinalg.py, not in the rewriting tests. Sorry about that as well; I'll move them.
Also, a heads up: I have to do some convoluted checks for rewriting triu/tril. I've tested that these work, in that they trigger the rewrite and return a correct answer, but I'm not sure if there is an easier way to do the check. Please let me know if there is. Thank you.
Double check the is_triangular check for the LU/QR cases, then I think this is done! Really great work!
Thank you, @jessegrabowski . Really appreciate the help in getting this over the line.
I tried improving test coverage, but Codecov was complaining about a few paths. Do I need to get the test coverage to 100%, or can we make a call there? I'm also wondering if checking the tag, as I do here, is necessary. Could you please let me know? Thank you!
> Do I need to get the test coverage to 100%?

No.
> I'm also wondering if checking the tag, as I do here, is necessary.

It's not necessary, but some Ops do set that flag, so we might as well check for it.
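A sketch of what such a tag check could look like (the `lower_triangular` flag name and the stand-in classes are hypothetical, not pytensor's API): `getattr` with a default keeps variables without the flag safe.

```python
# Stand-ins for a graph variable carrying structural hints on .tag
class FakeTag:
    lower_triangular = True

class TaggedVariable:
    tag = FakeTag()

class UntaggedVariable:
    tag = object()  # no triangularity hint at all

def is_tagged_lower_triangular(var):
    # Consult the flag if present; fall back to False otherwise
    return getattr(var.tag, "lower_triangular", False)

assert is_tagged_lower_triangular(TaggedVariable())
assert not is_tagged_lower_triangular(UntaggedVariable())
```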