Add linalg Ops to MLX backend
Closes #1693
Codecov Report
:x: Patch coverage is 92.77108% with 6 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 81.70%. Comparing base (ab5037e) to head (1eb414f).
:warning: Report is 10 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| pytensor/link/mlx/dispatch/slinalg.py | 90.47% | 2 Missing and 2 partials :warning: |
| pytensor/link/mlx/dispatch/nlinalg.py | 94.87% | 1 Missing and 1 partial :warning: |
:x: Your patch status has failed because the patch coverage (92.77%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## main #1700 +/- ##
==========================================
+ Coverage 81.64% 81.70% +0.06%
==========================================
Files 244 246 +2
Lines 53590 53632 +42
Branches 9438 9438
==========================================
+ Hits 43752 43820 +68
+ Misses 7356 7330 -26
Partials 2482 2482
| Files with missing lines | Coverage Δ |
|---|---|
| pytensor/link/mlx/dispatch/__init__.py | 100.00% <100.00%> (ø) |
| pytensor/link/mlx/dispatch/nlinalg.py | 94.87% <94.87%> (ø) |
| pytensor/link/mlx/dispatch/slinalg.py | 90.47% <90.47%> (ø) |
@ricardoV94 @jessegrabowski
The problem is NOT with the Blockwise dispatcher or with "useless vmap" errors. The real issue is:
- Linalg operations are wrapped in Blockwise: calling pt.linalg.inv(A) returns Blockwise(MatrixInverse)(A).
- The rewrite system removes Blockwise for matrix inputs: the local_useless_unbatched_blockwise rewrite (part of "fast_run") detects when there are no batch dimensions (batch_ndim == 0) and unwraps the Blockwise, leaving just the core op (MatrixInverse, SVD, etc.).
- You are registering handlers at the wrong level: The handlers were registered for the Blockwise-wrapped versions, but after the rewrite, only the core ops remain.
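To make the rewrite step concrete, here is a toy model of it. It is only a sketch under stated assumptions: `MatrixInverse`, `Blockwise`, `Node`, `inv` and `useless_unbatched_blockwise` below are hypothetical pure-Python stand-ins, not PyTensor's classes, and they mimic nothing more than the batch_ndim == 0 check described above.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for PyTensor's Op / Blockwise / graph node (not the real classes).
@dataclass(frozen=True)
class MatrixInverse:
    """Core op: acts on a single (m, m) matrix."""
    core_ndim: int = 2


@dataclass(frozen=True)
class Blockwise:
    """Wraps a core op so it can broadcast over leading batch dimensions."""
    core_op: object


@dataclass
class Node:
    op: object
    input_ndim: int  # ndim of the (single) input

    @property
    def batch_ndim(self) -> int:
        # Only a Blockwise-wrapped op can carry batch dimensions
        if isinstance(self.op, Blockwise):
            return self.input_ndim - self.op.core_op.core_ndim
        return 0


def useless_unbatched_blockwise(node: Node) -> Node:
    """Toy rewrite: drop the Blockwise wrapper when there are no batch dims."""
    if isinstance(node.op, Blockwise) and node.batch_ndim == 0:
        return Node(node.op.core_op, node.input_ndim)
    return node


def inv(input_ndim: int) -> Node:
    # Like pt.linalg.inv, always build the Blockwise-wrapped version first
    return Node(Blockwise(MatrixInverse()), input_ndim)


matrix_node = useless_unbatched_blockwise(inv(2))   # plain matrix input
batched_node = useless_unbatched_blockwise(inv(3))  # one leading batch dim
print(type(matrix_node.op).__name__)   # MatrixInverse
print(type(batched_node.op).__name__)  # Blockwise
```

A backend that only knows how to handle `Blockwise` would therefore never see the unwrapped `matrix_node`.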
The fix is to register the MLX funcify handlers for the core ops directly, not for the Blockwise-wrapped versions:
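import mlx.core as mx
from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.nlinalg import MatrixInverse

@mlx_funcify.register(MatrixInverse)
def mlx_funcify_MatrixInverse(op, node, **kwargs):
    # Return a callable that MLX runs on the op's input
    def inv(x):
        # assumption: MLX linalg routines currently require the CPU stream
        return mx.linalg.inv(x, stream=mx.cpu)

    return inv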
With that in place, the same handler covers both cases:
- NO batch dims: the rewrite removes the Blockwise, so the handler processes the core op directly.
- Batch dims present: the Blockwise dispatcher calls mlx_funcify(op.core_op, ...), so the same handler is reused.
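A minimal sketch of why registering on the core op covers both cases, assuming a `functools.singledispatch` registry as a stand-in for `mlx_funcify` and NumPy in place of MLX. The `funcify`, `MatrixInverse` and `Blockwise` names here are hypothetical, and `np.vectorize` only illustrates looping over batch dims; it is not how PyTensor's Blockwise dispatcher is actually implemented.

```python
from functools import singledispatch

import numpy as np


class MatrixInverse:
    """Hypothetical core op: inverse of one (m, m) matrix."""


class Blockwise:
    """Hypothetical wrapper that broadcasts a core op over leading batch dims."""

    def __init__(self, core_op):
        self.core_op = core_op


@singledispatch
def funcify(op, **kwargs):
    # Stand-in for mlx_funcify: maps an op to a Python callable
    raise NotImplementedError(f"No handler for {type(op).__name__}")


@funcify.register(MatrixInverse)
def funcify_MatrixInverse(op, **kwargs):
    # Registered on the core op, so it is found whether or not Blockwise was rewritten away
    def inv(x):
        return np.linalg.inv(x)

    return inv


@funcify.register(Blockwise)
def funcify_Blockwise(op, **kwargs):
    # Reuse the core-op handler, then map it over the leading batch dimensions
    core_fn = funcify(op.core_op, **kwargs)
    return np.vectorize(core_fn, signature="(m,m)->(m,m)")


A = np.array([[2.0, 0.0], [0.0, 4.0]])
unbatched = funcify(MatrixInverse())(A)                             # no batch dims
batched = funcify(Blockwise(MatrixInverse()))(np.stack([A, 2 * A]))  # one batch dim
print(unbatched)
print(batched.shape)  # (2, 2, 2)
```

np.linalg.inv already broadcasts over leading dims by itself; the explicit vectorize is only there to mirror the Blockwise wrapper delegating to the core handler.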
I just tested it and it works like butter :)
> You are registering handlers at the wrong level: The handlers were registered for the Blockwise-wrapped versions, but after the rewrite, only the core ops remain.
I don't think this was happening: we usually don't even have predefined Blockwise-wrapped Ops you can grab (there are a few, like _matmul, but that's the exception). Anyway, if it works, it works.