
Add linalg Ops to MLX backend

Open cetagostini opened this pull request 1 month ago

Closes #1693

cetagostini avatar Oct 27 '25 20:10 cetagostini

Codecov Report

:x: Patch coverage is 92.77108% with 6 lines in your changes missing coverage. Please review.
:white_check_mark: Project coverage is 81.70%. Comparing base (ab5037e) to head (1eb414f).
:warning: Report is 10 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/mlx/dispatch/slinalg.py 90.47% 2 Missing and 2 partials :warning:
pytensor/link/mlx/dispatch/nlinalg.py 94.87% 1 Missing and 1 partial :warning:

:x: Your patch status has failed because the patch coverage (92.77%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1700      +/-   ##
==========================================
+ Coverage   81.64%   81.70%   +0.06%     
==========================================
  Files         244      246       +2     
  Lines       53590    53632      +42     
  Branches     9438     9438              
==========================================
+ Hits        43752    43820      +68     
+ Misses       7356     7330      -26     
  Partials     2482     2482              
Files with missing lines Coverage Δ
pytensor/link/mlx/dispatch/__init__.py 100.00% <100.00%> (ø)
pytensor/link/mlx/dispatch/nlinalg.py 94.87% <94.87%> (ø)
pytensor/link/mlx/dispatch/slinalg.py 90.47% <90.47%> (ø)

... and 9 files with indirect coverage changes


codecov[bot] avatar Oct 28 '25 16:10 codecov[bot]

@ricardoV94 @jessegrabowski

The problem is NOT with the blockwise dispatcher or with "useless vmap" errors. The real issue is:

  1. Linalg operations are wrapped in Blockwise: When you call pt.linalg.inv(A), it returns Blockwise(MatrixInverse)(A)
  2. The rewrite system removes Blockwise for matrix inputs: The local_useless_unbatched_blockwise rewrite (part of "fast_run") detects when there are no batch dimensions (batch_ndim == 0) and unwraps the Blockwise, leaving just the core op (MatrixInverse, SVD, etc.)
  3. You are registering handlers at the wrong level: The handlers were registered for the Blockwise-wrapped versions, but after the rewrite, only the core ops remain.
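The failure mode described above can be sketched with a toy single-dispatch registry. This is not the real pytensor API, just a minimal stand-in (the `Blockwise` and `MatrixInverse` classes here are hypothetical stubs) showing why a handler registered only for the wrapper is never found once the rewrite has unwrapped it:

```python
# Toy sketch (NOT the real pytensor API): hypothetical stand-in classes
# illustrating why registering a handler on the Blockwise wrapper misses
# graphs where local_useless_unbatched_blockwise already unwrapped it.
from functools import singledispatch


class MatrixInverse:  # stand-in for the core Op
    pass


class Blockwise:  # stand-in for the Blockwise wrapper
    def __init__(self, core_op):
        self.core_op = core_op


@singledispatch
def mlx_funcify(op, **kwargs):
    raise NotImplementedError(f"No MLX handler for {type(op).__name__}")


# Handler registered only for the *wrapped* version:
@mlx_funcify.register(Blockwise)
def _(op, **kwargs):
    return f"blockwise({type(op.core_op).__name__})"


# Batched graph (Blockwise present): dispatch succeeds.
mlx_funcify(Blockwise(MatrixInverse()))  # -> 'blockwise(MatrixInverse)'

# Unbatched graph (rewrite removed Blockwise): dispatch fails, because
# nothing was registered for the bare core op.
# mlx_funcify(MatrixInverse())  # -> NotImplementedError
```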

So the fix is to register the MLX funcify handlers for the core ops directly, not for the Blockwise-wrapped versions:

@mlx_funcify.register(MatrixInverse)
def mlx_funcify_MatrixInverse(op, node=None, **kwargs):
    def matrix_inverse(x):
        # sketch: MLX's linalg ops currently run on the CPU stream
        return mx.linalg.inv(x, stream=mx.cpu)
    return matrix_inverse

Then:
  • When there are NO batch dims: the rewrite removes Blockwise → the handler processes the core op directly.
  • When there ARE batch dims: the Blockwise dispatcher calls mlx_funcify(op.core_op, ...) → the same handler is used.
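Both paths can be sketched together. Again a hypothetical toy, not the real pytensor API: a single generic Blockwise handler delegates to `mlx_funcify(op.core_op)`, so one registration per core op covers the batched and unbatched cases alike:

```python
# Hypothetical sketch of the "register at the core-op level" fix, with
# stand-in classes. One generic Blockwise handler delegates to the core
# op's handler, so each core op needs exactly one registration.
from functools import singledispatch


class MatrixInverse:  # stand-in for the core Op
    pass


class Blockwise:  # stand-in for the Blockwise wrapper
    def __init__(self, core_op):
        self.core_op = core_op


@singledispatch
def mlx_funcify(op, **kwargs):
    raise NotImplementedError(f"No MLX handler for {type(op).__name__}")


@mlx_funcify.register(Blockwise)
def _(op, **kwargs):
    # Batched path: delegate to whatever handler the core op has.
    return mlx_funcify(op.core_op, **kwargs)


@mlx_funcify.register(MatrixInverse)
def _(op, **kwargs):
    # Unbatched path: hit directly once the rewrite unwraps Blockwise.
    return "inv"


assert mlx_funcify(MatrixInverse()) == "inv"             # no batch dims
assert mlx_funcify(Blockwise(MatrixInverse())) == "inv"  # batch dims
```

The delegation in the Blockwise handler is the key design point: the core-op registration becomes the single source of truth, regardless of which form the graph is in after rewriting.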

I just tested it and it works like butter :)

cetagostini avatar Oct 30 '25 11:10 cetagostini

You are registering handlers at the wrong level: The handlers were registered for the Blockwise-wrapped versions, but after the rewrite, only the core ops remain.

I don't think this was happening; we usually don't even have pre-defined Blockwise-wrapped Ops you can grab (we have some, like _matmul, but that's the exception). Anyway, if it works, it works.

ricardoV94 avatar Oct 30 '25 11:10 ricardoV94