Pinv
Proposed changes
Add a Moore-Penrose pseudo-inverse function. Inspired by the recent PRs from @nicolov adding svd and inv, this PR adds the pinv primitive.
Tests
Ran some tests locally and included them in the PR:
>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
>>> A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0]])
>>> A_pinv = mx.linalg.pinv(A)
>>> A @ A_pinv @ A
array([[1, 2, 1, 1, 9],
       [3, 4, 2, 2, 8],
       [2, 2, 1, 0.999999, 4],
       [5, 6, 7, 2, 3]], dtype=float32)
Re-tested, and everything looks good:
>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
>>> A = mx.array([[1, 2], [3, 4], [2, 2], [5, 3.0]])
>>> A_pinv = mx.linalg.pinv(A)
>>> print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
(4, 2) , allclose? array(True, dtype=bool)
>>> A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0]])
>>> A_pinv = mx.linalg.pinv(A)
>>> print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
(4, 5) , allclose? array(True, dtype=bool)
>>> A = mx.array([[1.0, 2], [3, 4.0]])
>>> A_pinv = mx.linalg.pinv(A)
>>> print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
(2, 2) , allclose? array(True, dtype=bool)
Checklist
Put an x in the boxes that apply.
- [x] I have read the CONTRIBUTING document
- [x] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have updated the necessary documentation (if needed)
@adhulipa I think this implementation is the wrong way to go about it. Using SVD to compute the pseudo-inverse means we don't need a primitive, kernels, etc. It is just an op that can reside in the mlx::core::linalg namespace.
Basically something like the following (off the top of my head, so YMMV):
def pinv(x):
    U, S, V = mx.linalg.svd(x)
    return (V[:len(x)].T * 1/S) @ U.T
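For reference, a runnable rendering of that sketch might look like the following. This is only an illustrative sketch, not the PR's implementation; the name pinv_sketch and the stream=mx.cpu argument are assumptions here. It slices U and Vt down to rank k so rectangular inputs also work (the shape issue that comes up later in this thread):

import mlx.core as mx

def pinv_sketch(x):
    # pinv(A) = Vt.T @ diag(1/S) @ U.T, assembled from the SVD.
    # mx.linalg.svd (at the time of this PR) returns full U (m x m)
    # and Vt (n x n), so both are sliced down to k = min(m, n).
    U, S, Vt = mx.linalg.svd(x, stream=mx.cpu)
    k = min(x.shape[-2], x.shape[-1])
    return (Vt[:k].T * (1.0 / S)) @ U[:, :k].T

For square inputs the slicing is a no-op and this reduces to the one-liner above.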
Ahh I see! I didn’t think about that. Thanks for the review @angeloskath
I suppose we could modify this PR to merge a Python-only version as a first step and then investigate whether a custom kernel is necessary.
Would you recommend such a direction?
The op should be in C++ and then do a binding (we try to keep the C++ and Python APIs reasonably consistent). I think the Python impl from @angeloskath is just intended as pseudo-code for that.
Ah yes, that makes sense. I should add a Python API that matches the C++ API for pinv(). I haven't gotten around to it yet. Thank you for taking a look, folks!
I made a few updates. I still have to figure out how to fix the C++ op: svd(A) returns U, S, Vt as full matrices (U is m x m, Vt is n x n), so for rectangular A the matmuls are shape-incompatible.
I have a path to green where I need to slice U to match the expected output shape.
(I'm confident the SVD approach is accurate because I validated it in MLX's Python API, and in a few other languages such as MATLAB, to be certain.)
Also, the PyTorch impl can serve as a reference: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/aten/src/ATen/native/LinearAlgebra.cpp#L532
I just need to get around to it in due time.
@adhulipa are you planning to come back to this?
@adhulipa are you planning to return to this one?
Hi @awni, yes, I will update this one. I am running into an issue: I haven't figured out how to allocate the rectangular output/result array before passing it off to the pinv function.
Apologies for the delay; other priorities took precedence lately.
I think I should be able to dedicate a few hours this weekend -- likely 4-8 hours on 5/11
@awni do you think it's better to close this PR and reopen it against a newer mainline commit? Happy to do so if it helps keep your PR to-do list clean.
It's more up to you. If you plan to work on it in the near future, then you can keep it open (or start a new one if you prefer). If not, I would close it.
I'll keep it open for now, and I'll close and reopen if it falls too far behind -- for now these changes are additive, so that's not a risk. It just needs a bit of polish/bugfixing.
Made some progress. Need to fix a few more things.
I suspect the issue is in how I'm using mx.linalg.svd(A) -- or rather, in C++, svd_impl() -- and then in the matmuls for getting the pinv.
I'm seeing:
>>> A = mx.array([[1.0, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3] ])
>>> U, S, Vt = mx.linalg.svd(A)
>>> U @ mx.diag(S) @ Vt
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: [matmul] Last dimension of first input with shape (4,4) must match second to last dimension of second input with shape (5,5).
This seems to contradict the MLX svd docs:
Returns The U, S, and Vt matrices, such that A = U @ diag(S) @ Vt
Of course, I acknowledge that MLX mimics the NumPy API, and NumPy indeed produces a similar result. But NumPy supports a full_matrices: bool = True kwarg, which I suppose was designed to help with these kinds of cases:
>>> A = np.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3]])
>>> U, S, Vt = np.linalg.svd(A, full_matrices=False)
>>> U @ np.diag(S) @ Vt
array([[1., 2., 1., 1., 9.],
       [3., 4., 2., 2., 8.],
       [2., 2., 1., 1., 4.],
       [5., 6., 7., 2., 3.]])
>>> np.allclose(A, U @ np.diag(S) @ Vt)
True
(FWIW, without full_matrices=False, the error is the same as in MLX:)
>>> U, S, Vt = np.linalg.svd(A)
>>> U @ np.diag(S) @ Vt
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 5 is different from 4)
I think you can just do something like this:
U, S, V = mx.linalg.svd(A)
K = min(A.shape[0], A.shape[1])
Atilde = (U[:, :K] * S) @ V[:K, :]
We could add the slicing as an option like Numpy if it's useful.
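As a quick sanity check of that identity on the 4x5 matrix from earlier in the thread (an illustrative aside, not part of the PR):

import mlx.core as mx

A = mx.array([[1.0, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3]])
U, S, V = mx.linalg.svd(A, stream=mx.cpu)
K = min(A.shape[0], A.shape[1])
Atilde = (U[:, :K] * S) @ V[:K, :]        # thin-SVD reconstruction of A
print(mx.allclose(A, Atilde, atol=1e-5))  # expect: array(True, dtype=bool)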
Also I would recommend you rebase before making further progress to make it easier to resolve conflicts.
Ah, thanks Awni! Will use that. Ack on the rebase.
Small update: Got a local build that correctly computes pinv in most of the tests. Cleaning up some things and polishing up the code.
>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
>>> A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0]])
>>> A_pinv = mx.linalg.pinv(A)
>>> A_pinv
array([[2.25371e-07, -1, 2, -7.40725e-08],
       [-0.408602, 1.2043, -1.44086, -0.064516],
       [0.363441, -0.48172, -0.0236563, 0.225806],
       [-0.346237, 0.873119, -0.894624, -0.0967741],
       [0.2, -0.2, 0.2, 1.00349e-08]], dtype=float32)
>>> A @ A_pinv @ A
array([[1, 2, 1, 1, 9],
       [3, 4, 2, 2, 8],
       [2, 2, 1, 0.999999, 4],
       [5, 6, 7, 2, 3]], dtype=float32)
>>> ans = A @ A_pinv @ A
>>> mx.allclose(A, A @ A_pinv @ A)
array(True, dtype=bool)
Turns out I was incorrectly relying on the lazy computation-graph API instead of forcing the actual matrix products to be computed (D'oh!). Now I have some code locally using a BLAS matrix-multiply routine (sgemm) to compute the final pinv product.
Will update this PR soon
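(For context, the pitfall described above is MLX's lazy evaluation: ops only build a computation graph until something forces them to run. A minimal illustration, independent of this PR:)

import mlx.core as mx

a = mx.ones((2, 2))
b = a @ a     # lazy: builds a graph node, nothing is computed yet
mx.eval(b)    # explicitly materializes the result
print(b)      # printing also forces evaluation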
Updated the PR. It is in good enough shape for a review from @awni and other MLX folks. Thanks!
Perhaps there's one more thing to check (on my part) in the python tests. Will look into it. But in the meantime, this PR is still good for a review.
Drats. I have another bug to fix. I updated the tests to catch it and will look into fixing it. Essentially, tall rectangular matrices hit a matmul dim mismatch -- which means I have made an error in the m, n, k calculations and/or the slice selections of U/Vt.
Fixed the bug for rectangular matrices where M > N 🎉 Will publish the commit soon.
This PR is ready for a review from Awni, Angelos and other MLX folks. Thanks!
Hi @adhulipa . I think there shouldn't be a primitive for this operation. It can really just be an op in the linalg namespace.
Hi @angeloskath, ohh I see. I think I may have misinterpreted something in the thread here, particularly what @awni shared after your earlier comment:
The op should be in C++ and then do a binding
Is it accurate to say that you meant this should live in linalg.cpp, where we add something relatively simple like:
auto outs = linalg::svd(x);
array U = outs[0];
array S = outs[1];
... // etc.
return (V[:len(x)].T * 1/S) @ U.T  // of course, in the C++ variant instead of the python-esque here
Actually, @angeloskath, do you mean that we don't need a primitive, but that all the logic -- calling linalg::svd(), ensuring the svd() call is eval()'d, and then passing the results to the sgemm() calls -- should live in the linalg.cpp file (and namespace)?
It seems like the recommendation here is to keep the core logic intact but just not make this a primitive. Am I understanding that correctly?
I think I'm starting to understand the motivation behind the C++-op-sans-primitive recommendation from Angelos. Pardon the roundabout way I needed to understand this 😅
The following change in linalg.cpp does pass the tests. Just checking a few more things before I publish a new commit.
array pinv(const array& a, StreamOrDevice s /* = {} */) {
  ....
  const auto m = a.shape(-2);
  const auto n = a.shape(-1);
  const auto k = std::min(m, n);
  const auto rank = a.ndim();
  auto outs = linalg::svd(a, Device::cpu);
  auto U = outs[0];
  auto S = outs[1];
  auto Vt = outs[2];
  ....
  ....
  const auto U_slice = slice(U, {0, 0}, {m, k});
  const auto Vt_slice = slice(Vt, {0, 0}, {k, n});
  return matmul(matmul(transpose(Vt_slice), diag(1.0 / S)), transpose(U_slice));
}
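As an aside, a candidate pseudo-inverse can be validated against all four Moore-Penrose conditions from the Python side. A small hypothetical helper (check_pinv is not part of the PR or its test suite):

import mlx.core as mx

def check_pinv(A, A_pinv, atol=1e-5):
    # The four Moore-Penrose conditions that uniquely define pinv(A).
    conds = [
        mx.allclose(A @ A_pinv @ A, A, atol=atol),            # A A+ A == A
        mx.allclose(A_pinv @ A @ A_pinv, A_pinv, atol=atol),  # A+ A A+ == A+
        mx.allclose((A @ A_pinv).T, A @ A_pinv, atol=atol),   # A A+ is symmetric
        mx.allclose((A_pinv @ A).T, A_pinv @ A, atol=atol),   # A+ A is symmetric
    ]
    return all(bool(c.item()) for c in conds)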
@angeloskath @awni -- question: do you folks feel this is in good shape for a review? Of course, no rush from my POV; just thought I'd check.
Hi @angeloskath @awni -- I'm curious whether you think this PR is ready for review or merge? Or perhaps there's another way to build this functionality? I'm thankful for all the feedback so far and the pointers on the road forward : )
Thank you Angelos! No need to apologize re: delay : ) -- I appreciate the review + feedback. I will address and update the PR this week.
Looks like a test failed. I will investigate:
FAIL [0.006s]: test_pseudo_inverse (test_linalg.TestLinalg)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/Users/distiller/project/python/tests/test_linalg.py", line 173, in test_pseudo_inverse
self.assertTrue(mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-6))
AssertionError: array(False, dtype=bool) is not true
It looks like the failure is a matter of precision.
For the square-matrix case, 1e-5 is a passing tolerance, whereas the existing 1e-6 level causes a failure:
>>> A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
>>> A_plus = mx.linalg.pinv(A, stream=mx.cpu)
>>> mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-5)
array(True, dtype=bool)
>>> mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-6)
array(False, dtype=bool)
Similarly, for a 2x3x3 stacked input, the result is within 1e-3 but not within 1e-6:
>>> B = A - 100
>>> AB = mx.stack([A, B])
>>> pinvs = mx.linalg.pinv(AB, stream=mx.cpu)
>>> pinvs
array([[[0.313084, -0.0467289, -0.107477],
        [0.364486, -0.158878, -0.0654206],
        [-0.0140187, 0.121495, 0.0794393]],
       [[0.0915965, -0.0186531, -0.0762813],
        [0.169056, -0.134106, -0.0378953],
        [-0.274593, 0.154526, 0.11614]]], dtype=float32)
>>> for M, M_plus in zip(AB, pinvs):
...     print(M @ M_plus @ M)
...     print(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))
...
array([[1, 2, 3],
       [6, -5, 4],
       [-9, 8, 7]], dtype=float32)
array(True, dtype=bool)
array([[-98.9998, -97.9998, -96.9998],
       [-93.9997, -105, -95.9998],
       [-109, -91.9998, -92.9998]], dtype=float32)
array(True, dtype=bool)
Hey there Angelos (@angeloskath), curious whether you or others on the MLX team could come back to this PR sometime? Thanks!