Add Ops for LU Factorization
Description
This PR will add the following Ops:
- [x] lu
- [x] lu_factor
- [x] lu_solve
It also adds dispatches for numba/jax (and maybe torch, though help is welcome there).
The motivation is faster gradients for solve. I think this is a major reason why jax has faster gradients than us (at least when solve is involved): they route everything through lu_solve(lu_factor(A), b) and reuse lu_factor(A) in the backward pass, so the backward pass does not have to factorize A a second time.
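To make the reuse idea concrete, here is a minimal sketch that uses SciPy's `lu_factor` and `lu_solve` as stand-ins. It is not the PyTensor Ops from this PR and not jax's actual implementation; the helper name `solve_with_vjp` is hypothetical and exists only for illustration. It factorizes `A` once, solves the forward system, and then solves the transposed system for the gradient with the same factorization.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve


def solve_with_vjp(A, b):
    """Solve A x = b and return x plus a backward function (hypothetical helper)."""
    # One O(n^3) factorization, shared by the forward and backward passes.
    lu_and_piv = lu_factor(A)
    x = lu_solve(lu_and_piv, b)

    def vjp(x_bar):
        # Gradient w.r.t. b solves A^T b_bar = x_bar; trans=1 reuses the same LU.
        b_bar = lu_solve(lu_and_piv, x_bar, trans=1)
        # Gradient w.r.t. A follows from dx = -A^{-1} dA x.
        A_bar = -np.outer(b_bar, x)
        return A_bar, b_bar

    return x, vjp


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4)) + 4.0 * np.eye(4)  # diagonal shift keeps A well conditioned
b = rng.normal(size=4)

x, vjp = solve_with_vjp(A, b)
# Gradients of loss = x.sum() with respect to A and b.
A_bar, b_bar = vjp(np.ones_like(x))
```

Each call to `lu_solve` costs O(n^2), so once the factorization is available the backward pass is cheap compared with factorizing `A` again.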
Related Issue
- [ ] Closes #
- [ ] Related to #
Checklist
- [ ] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [ ] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
📚 Documentation preview 📚: https://pytensor--1218.org.readthedocs.build/en/1218/
Any benchmarks on solve written with these Ops?
Working on that next, just ironing out some bugs in the lu_solve Op (which is what all this is building towards)
Codecov Report
Attention: Patch coverage is 74.80720% with 98 lines in your changes missing coverage. Please review.
Project coverage is 82.07%. Comparing base (676296c) to head (1e42069). Report is 172 commits behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #1218 +/- ##
==========================================
+ Coverage 82.05% 82.07% +0.02%
==========================================
Files 203 206 +3
Lines 48863 49174 +311
Branches 8695 8720 +25
==========================================
+ Hits 40093 40359 +266
- Misses 6619 6656 +37
- Partials 2151 2159 +8
| Files with missing lines | Coverage Δ | |
|---|---|---|
| pytensor/link/numba/dispatch/basic.py | 80.38% <ø> (ø) | |
| ...tensor/link/numba/dispatch/linalg/solve/general.py | 51.72% <100.00%> (+7.17%) | :arrow_up: |
| pytensor/tensor/blockwise.py | 90.40% <100.00%> (+4.75%) | :arrow_up: |
| pytensor/tensor/elemwise.py | 90.01% <ø> (+0.41%) | :arrow_up: |
| pytensor/link/jax/dispatch/slinalg.py | 85.33% <92.30%> (+3.70%) | :arrow_up: |
| pytensor/tensor/slinalg.py | 93.10% <92.20%> (-0.30%) | :arrow_down: |
| pytensor/link/numba/dispatch/slinalg.py | 69.76% <71.42%> (+0.66%) | :arrow_up: |
| ...k/numba/dispatch/linalg/decomposition/lu_factor.py | 56.81% <56.81%> (ø) | |
| ...sor/link/numba/dispatch/linalg/decomposition/lu.py | 66.66% <66.66%> (ø) | |
| ...ensor/link/numba/dispatch/linalg/solve/lu_solve.py | 41.50% <41.50%> (ø) | |