Add Ops for LU Factorization

Open jessegrabowski opened this issue 10 months ago • 2 comments

Description

This PR will add the following Ops:

  • [x] lu
  • [x] lu_factor
  • [x] lu_solve

It also adds dispatches for Numba and JAX (and maybe PyTorch, though help is welcome there).
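For readers less familiar with the three operations, here is a minimal sketch of what each one computes. It uses SciPy's reference functions rather than the new PyTensor Ops, and assumes the Ops mirror the semantics of `scipy.linalg.lu`, `scipy.linalg.lu_factor` and `scipy.linalg.lu_solve`.

```python
import numpy as np
from scipy.linalg import lu, lu_factor, lu_solve

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))
b = rng.normal(size=4)
B = rng.normal(size=(4, 2))  # several right-hand sides at once

# lu: explicit factors such that A = P @ L @ U
P, L, U = lu(A)

# lu_factor: packed form. U sits in the upper triangle and L in the strict
# lower triangle (L's unit diagonal is not stored); piv[i] is the row that
# row i was interchanged with during partial pivoting.
lu_and_piv = lu_factor(A)

# lu_solve: reuse the packed factorization, with no refactoring, for
# any number of right-hand sides.
x = lu_solve(lu_and_piv, b)
X = lu_solve(lu_and_piv, B)
```

The point of splitting `lu_factor` from `lu_solve` is that the O(n^3) factorization is paid once, while each subsequent solve only costs two O(n^2) triangular solves.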

The motivation is to make the gradients of solve faster. I think this is a major reason why JAX computes gradients faster than we do (at least when solve is involved): it routes everything through lu_solve(lu_factor(A), b) and reuses lu_factor(A) in the backward pass.
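A minimal NumPy/SciPy sketch of the reuse idea follows. It is not the PR's implementation and not JAX's internals; `solve_forward` and `solve_backward` are hypothetical names. For x = A⁻¹b and an upstream gradient g = ∂loss/∂x, the standard adjoints are b̄ = A⁻ᵀg and Ā = -b̄xᵀ, so the backward pass needs a solve with Aᵀ. That solve can reuse the forward factorization via `lu_solve(..., trans=1)`.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve


def solve_forward(A, b):
    """Hypothetical forward pass: factor A once and keep the factors."""
    lu_piv = lu_factor(A)       # O(n^3), done only here
    x = lu_solve(lu_piv, b)     # solves A x = b
    return x, lu_piv            # lu_piv is saved for the backward pass


def solve_backward(lu_piv, x, g):
    """Hypothetical backward pass for x = solve(A, b) with 1-D b.

    g is the gradient of the loss with respect to x.
    """
    # trans=1 solves A^T y = g using the *same* LU factors (O(n^2))
    b_bar = lu_solve(lu_piv, g, trans=1)
    # Gradient with respect to A is -b_bar x^T
    A_bar = -np.outer(b_bar, x)
    return A_bar, b_bar
```

Without the saved factors, the backward pass would have to call a fresh `solve(A.T, g)`, repeating the O(n^3) factorization.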

Related Issue

  • [ ] Closes #
  • [ ] Related to #

Checklist

Type of change

  • [ ] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

📚 Documentation preview 📚: https://pytensor--1218.org.readthedocs.build/en/1218/

jessegrabowski · Feb 18 '25 15:02

Any benchmarks on solve written with these Ops?

ricardoV94 · Feb 28 '25 10:02

Working on that next. I'm just ironing out some bugs in the lu_solve Op, which is what all of this is building towards.

jessegrabowski · Feb 28 '25 10:02

Codecov Report

Attention: Patch coverage is 74.80720% with 98 lines in your changes missing coverage. Please review.

Project coverage is 82.07%. Comparing base (676296c) to head (1e42069). Report is 172 commits behind head on main.

Files with missing lines                                 Patch %   Lines
...ensor/link/numba/dispatch/linalg/solve/lu_solve.py    41.50%    31 Missing ⚠️
...sor/link/numba/dispatch/linalg/decomposition/lu.py    66.66%    20 Missing ⚠️
...k/numba/dispatch/linalg/decomposition/lu_factor.py    56.81%    19 Missing ⚠️
pytensor/link/numba/dispatch/slinalg.py                  71.42%    8 Missing and 6 partials ⚠️
pytensor/tensor/slinalg.py                               92.20%    6 Missing and 6 partials ⚠️
pytensor/link/jax/dispatch/slinalg.py                    92.30%    1 Missing and 1 partial ⚠️
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #1218      +/-   ##
==========================================
+ Coverage   82.05%   82.07%   +0.02%     
==========================================
  Files         203      206       +3     
  Lines       48863    49174     +311     
  Branches     8695     8720      +25     
==========================================
+ Hits        40093    40359     +266     
- Misses       6619     6656      +37     
- Partials     2151     2159       +8     
Files with missing lines                                 Coverage Δ
pytensor/link/numba/dispatch/basic.py                    80.38% <ø> (ø)
...tensor/link/numba/dispatch/linalg/solve/general.py    51.72% <100.00%> (+7.17%) ⬆️
pytensor/tensor/blockwise.py                             90.40% <100.00%> (+4.75%) ⬆️
pytensor/tensor/elemwise.py                              90.01% <ø> (+0.41%) ⬆️
pytensor/link/jax/dispatch/slinalg.py                    85.33% <92.30%> (+3.70%) ⬆️
pytensor/tensor/slinalg.py                               93.10% <92.20%> (-0.30%) ⬇️
pytensor/link/numba/dispatch/slinalg.py                  69.76% <71.42%> (+0.66%) ⬆️
...k/numba/dispatch/linalg/decomposition/lu_factor.py    56.81% <56.81%> (ø)
...sor/link/numba/dispatch/linalg/decomposition/lu.py    66.66% <66.66%> (ø)
...ensor/link/numba/dispatch/linalg/solve/lu_solve.py    41.50% <41.50%> (ø)

... and 2 files with indirect coverage changes

codecov[bot] · Apr 19 '25 15:04