
Pinv

adi-dhulipala opened this pull request

Proposed changes

Add a Moore-Penrose pseudoinverse function. Inspired by the recent PRs from @nicolov adding svd and inv, this PR adds the pinv primitive.

Tests

Ran some tests locally and included them in the PR:

>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
>>> A @ A_pinv @ A
array([[1, 2, 1, 1, 9],
       [3, 4, 2, 2, 8],
       [2, 2, 1, 0.999999, 4],
       [5, 6, 7, 2, 3]], dtype=float32)

Re-tested, and everything looks good:


>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2], [3, 4], [2, 2], [5, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
... print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
...
... import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
... print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
...
... import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1.0, 2], [3, 4.0]])
... A_pinv = mx.linalg.pinv(A)
... print(A.shape, ",\tallclose? ", mx.allclose(A, A @ A_pinv @ A))
(4, 2) ,	allclose?  array(True, dtype=bool)
(4, 5) ,	allclose?  array(True, dtype=bool)
(2, 2) ,	allclose?  array(True, dtype=bool)

Checklist

Put an x in the boxes that apply.

  • [x] I have read the CONTRIBUTING document
  • [x] I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] I have updated the necessary documentation (if needed)

adi-dhulipala avatar Mar 22 '24 03:03 adi-dhulipala

@adhulipa I think this implementation is the wrong way to go about it. Using SVD to compute the pseudoinverse means we don't need a primitive, kernels, etc. It is just an op that can reside in the mlx::core::linalg namespace.

Basically something like the following (off the top of my head, so YMMV):

def pinv(x):
    U, S, V = mx.linalg.svd(x)
    return (V[:len(x)].T * 1/S) @ U.T
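
Fleshed out into a runnable sketch (assuming the full-matrix svd output and nonzero singular values; pinv_sketch is just an illustrative name, not a proposed API):

import mlx.core as mx

def pinv_sketch(x):
    # svd returns full matrices: U is (m, m), S is (k,), Vt is (n, n)
    U, S, Vt = mx.linalg.svd(x, stream=mx.cpu)
    k = min(x.shape[0], x.shape[1])
    # pinv(x) = V_k @ diag(1/S) @ U_k^T, built from the thin slices
    return (Vt[:k, :].T * (1.0 / S)) @ U[:, :k].T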

angeloskath avatar Mar 22 '24 19:03 angeloskath

Ahh I see! I didn’t think about that. Thanks for the review @angeloskath

I suppose we could modify this PR to merge a Python-only version as a first step and then investigate whether a custom kernel is necessary.

Would you recommend such a direction?

adi-dhulipala avatar Mar 24 '24 20:03 adi-dhulipala

The op should be in C++ and then do a binding (we try to keep the C++ and Python APIs reasonably consistent). I think the Python impl from @angeloskath is just intended as pseudo-code for that.

awni avatar Mar 25 '24 22:03 awni

Ah yes, that makes sense. I should add the Python API that matches the C++ API for pinv(). I haven't gotten around to it yet. Thank you for taking a look, folks!

adi-dhulipala avatar Mar 28 '24 20:03 adi-dhulipala

I made a few updates. Still need to figure out how to fix the C++ op: svd(A) returns U, S, Vt where U and Vt are full (square) matrices, which makes the matmul shape-incompatible when A is rectangular.

I have a path to green where I slice U to match the expected end shape.

(I'm positive the SVD approach is accurate because I validated it with the MLX Python API, and in a few other languages such as MATLAB, to be certain.)

We can also use the PyTorch implementation as a reference: https://github.com/pytorch/pytorch/blob/2ffab6e663b9c6951048b8c8ba82d2cc5ca5c2fc/aten/src/ATen/native/LinearAlgebra.cpp#L532

I just need to get around to it in due time.

adi-dhulipala avatar Apr 05 '24 06:04 adi-dhulipala

@adhulipa are you planning to come back to this?

awni avatar Apr 25 '24 03:04 awni

@adhulipa are you planning to return to this one?

awni avatar May 03 '24 20:05 awni

Hi @awni, yes, I will update this one. I'm running into an issue: I haven't figured out how to allocate the rectangular output/result array before passing it off to the pinv function.

Apologies for the delay; other priorities took precedence lately.

I think I should be able to dedicate a few hours this weekend -- likely 4-8 hours on 5/11

adi-dhulipala avatar May 08 '24 03:05 adi-dhulipala

@awni do you think it’s better to close this PR and reopen against a newer mainline commit? Happy to do so if it helps keep your PR todo list clean

adi-dhulipala avatar May 13 '24 04:05 adi-dhulipala

It's more up to you. If you plan to work on it in the near future then you can keep it open (or start a new one if you prefer). If not, I would close it.

awni avatar May 13 '24 13:05 awni

I'll keep it open for now. I'll close and reopen if it falls significantly behind -- for now these changes are additive, so that's not a risk. It just needs a bit of polish/bugfixing.

adi-dhulipala avatar May 13 '24 19:05 adi-dhulipala

Made some progress. Need to fix a few more things.

adi-dhulipala avatar May 23 '24 00:05 adi-dhulipala

I suspect there's something I need to figure out with how I'm using mx.linalg.svd(A) -- or rather, in C++, svd_imp() -- and then the matmuls for getting the pinv.

I'm seeing:

>>> A = mx.array([[1.0, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3] ])
>>> U, S, Vt = mx.linalg.svd(A)
>>> U @ mx.diag(S) @ Vt
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: [matmul] Last dimension of first input with shape (4,4) must match second to last dimension of second input with shape (5,5).

This seems to contradict the MLX svd doc:

Returns: The U, S, and Vt matrices, such that A = U @ diag(S) @ Vt

Of course, I acknowledge that MLX mimics the NumPy API, and NumPy indeed produces a similar result. But NumPy supports a full_matrices: bool = True kwarg, which I suppose was designed to help with these kinds of cases:

A = np.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3] ])
U, S, Vt = np.linalg.svd(A, full_matrices=False)
U @ np.diag(S) @ Vt

>>> U @ np.diag(S) @ Vt
array([[1., 2., 1., 1., 9.],
       [3., 4., 2., 2., 8.],
       [2., 2., 1., 1., 4.],
       [5., 6., 7., 2., 3.]])

>>> np.allclose(A, U @ np.diag(S) @ Vt)
True

(FWIW, without full_matrices=False, the error is the same as MLX's:)

>>> U, S, Vt = np.linalg.svd(A)
>>> U @ np.diag(S) @ Vt
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 5 is different from 4)

adi-dhulipala avatar May 23 '24 05:05 adi-dhulipala

I think you can just do something like this:

U, S, V = mx.linalg.svd(A)
K = min(A.shape[0], A.shape[1])
Atilde = (U[:, :K] * S) @ V[:K, :]

We could add the slicing as an option like Numpy if it's useful.
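
A quick runnable check of that thin reconstruction (filled in around the snippet above, using the matrix from the earlier example):

import mlx.core as mx

A = mx.array([[1.0, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3]])
U, S, V = mx.linalg.svd(A, stream=mx.cpu)
K = min(A.shape[0], A.shape[1])
Atilde = (U[:, :K] * S) @ V[:K, :]
print(mx.allclose(A, Atilde, atol=1e-5))  # expected: array(True, dtype=bool)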

Also I would recommend you rebase before making further progress to make it easier to resolve conflicts.

awni avatar May 23 '24 13:05 awni

Ah, thanks Awni! Will use that. Ack on the rebase.

adi-dhulipala avatar May 23 '24 18:05 adi-dhulipala

Small update: Got a local build that correctly computes pinv in most of the tests. Cleaning up some things and polishing up the code.

>>> import mlx.core as mx; mx.set_default_device(mx.cpu)
... A = mx.array([[1, 2, 1, 1, 9], [3, 4, 2, 2, 8], [2, 2, 1, 1, 4], [5, 6, 7, 2, 3.0] ])
... A_pinv = mx.linalg.pinv(A)
>>> A_pinv
array([[2.25371e-07, -1, 2, -7.40725e-08],
       [-0.408602, 1.2043, -1.44086, -0.064516],
       [0.363441, -0.48172, -0.0236563, 0.225806],
       [-0.346237, 0.873119, -0.894624, -0.0967741],
       [0.2, -0.2, 0.2, 1.00349e-08]], dtype=float32)

>>> A @ A_pinv @ A
array([[1, 2, 1, 1, 9],
       [3, 4, 2, 2, 8],
       [2, 2, 1, 0.999999, 4],
       [5, 6, 7, 2, 3]], dtype=float32)

>>> ans = A @ A_pinv @ A
>>> mx.allclose(A, A @ A_pinv @ A)
array(True, dtype=bool)

Turns out I was incorrectly relying on the array computation graph API instead of computing the actual matrix products (d'oh!). Now I have some code locally using the BLAS/LAPACK matrix-multiply routine sgemm to compute the final pinv product; a quick illustration is below.
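
For illustration, the same kind of single-precision product can be exercised from Python through SciPy's thin BLAS wrappers (SciPy here is only an assumption for the sketch, not an MLX dependency; the C++ op calls the routine directly):

import numpy as np
from scipy.linalg.blas import sgemm  # thin wrapper over the BLAS sgemm routine

A = np.array([[1.0, 2], [3, 4], [5, 6]], dtype=np.float32)
B = np.array([[1.0, 0, 2], [0, 1, 1]], dtype=np.float32)
C = sgemm(alpha=1.0, a=A, b=B)  # C = 1.0 * (A @ B), computed in float32 by sgemm
assert np.allclose(C, A @ B)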

Will update this PR soon

adi-dhulipala avatar May 27 '24 06:05 adi-dhulipala

Updated the PR. It is in good enough shape for a review from @awni and other MLX folks. Thanks!

Perhaps there's one more thing to check (on my part) in the Python tests; will look into it. In the meantime, this PR is still good for a review.

adi-dhulipala avatar May 27 '24 07:05 adi-dhulipala

Drats, I have another bug to fix. I updated the tests to catch it; will look into it and fix. Essentially, tall rectangular matrices hit a matmul dim mismatch -- which means I have made an error in the m, n, k calculations and/or the slice selections of U/Vt.

adi-dhulipala avatar May 27 '24 22:05 adi-dhulipala

Fixed the bug for rectangular matrices where M > N 🎉 Will publish the commit soon.

adi-dhulipala avatar May 28 '24 02:05 adi-dhulipala

This PR is ready for a review from Awni, Angelos and other MLX folks. Thanks!

adi-dhulipala avatar May 28 '24 02:05 adi-dhulipala

Hi @adhulipa. I think there shouldn't be a primitive for this operation. It can really just be an op in the linalg namespace.

angeloskath avatar May 29 '24 18:05 angeloskath

Hi @angeloskath, ohh, I see. I think I may have misinterpreted something in the thread, then -- particularly what @awni shared after your earlier comment:

The op should be in C++ and then do a binding

Is it accurate to say that you meant this should go in linalg.cpp, where we add something relatively simple like:

auto outs = linalg::svd(x);
array U = outs[0];
array S = outs[1];
// ... etc.

return (V[:len(x)].T * 1/S) @ U.T  // of course, in the C++ variant instead of the Python-esque here

adi-dhulipala avatar May 29 '24 18:05 adi-dhulipala

Actually, @angeloskath, do you mean to say that we don't need a primitive, but that all the logic -- calling linalg::svd(), ensuring the svd() call is eval()'d, and then passing to LAPACK's sgemm() calls -- should all be in the linalg.cpp file (and namespace)?

It seems like the recommendation here is to keep the core logic intact but just not make this a primitive. Am I understanding that correctly?

adi-dhulipala avatar May 29 '24 18:05 adi-dhulipala

I think I'm starting to understand the motivation behind Angelos's recommendation of a C++ op sans primitive. Pardon the roundabout way I needed to understand this 😅

The following change in linalg.cpp does pass the tests. Just checking a few more things before I can publish a new commit.

array pinv(const array& a, StreamOrDevice s /* = {} */) {
....

  const auto m = a.shape(-2);
  const auto n = a.shape(-1);
  const auto k = std::min(m, n);
  const auto rank = a.ndim();

  auto outs = linalg::svd(a, Device::cpu);
  auto U = outs[0];
  auto S = outs[1];
  auto Vt = outs[2];

....
....

  // pinv(a) = V_k @ diag(1/S) @ U_k^T, built from the thin slices of U and Vt
  const auto U_slice = slice(U, {0, 0}, {m, k});
  const auto Vt_slice = slice(Vt, {0, 0}, {k, n});
  return matmul(matmul(transpose(Vt_slice), diag(1.0 / S)), transpose(U_slice));
}

adi-dhulipala avatar May 29 '24 21:05 adi-dhulipala

@angeloskath @awni -- question: do you folks feel like this is in good shape for a review? Of course, no rush from my POV; just thought I'd check.

adi-dhulipala avatar Jun 04 '24 23:06 adi-dhulipala

Hi @angeloskath @awni -- I'm curious whether you think this PR is ready for review or merge, or whether there's another way you'd build this functionality? I'm thankful for all the feedback and pointers on the road forward so far : )

adi-dhulipala avatar Jul 04 '24 02:07 adi-dhulipala

Thank you Angelos! No need to apologize re: delay : ) -- I appreciate the review + feedback. I will address and update the PR this week.

adi-dhulipala avatar Jul 11 '24 21:07 adi-dhulipala

Looks like a test failed. I will investigate:

FAIL [0.006s]: test_pseudo_inverse (test_linalg.TestLinalg)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/distiller/project/python/tests/test_linalg.py", line 173, in test_pseudo_inverse
    self.assertTrue(mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-6))
AssertionError: array(False, dtype=bool) is not true

adi-dhulipala avatar Jul 13 '24 00:07 adi-dhulipala

It looks like the failure is a matter of precision.

For the square-matrix case, 1e-5 is a passing tolerance, whereas the existing 1e-6 level causes a failure:

>>> A = mx.array([[1, 2, 3], [6, -5, 4], [-9, 8, 7]], dtype=mx.float32)
... A_plus = mx.linalg.pinv(A, stream=mx.cpu)
>>> mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-5)
array(True, dtype=bool)

>>> mx.allclose(A @ A_plus @ A, A, rtol=0, atol=1e-6)
array(False, dtype=bool)

Similarly, for a 2x3x3 batch of matrices, the results are within 1e-3 but not within 1e-6:

>>> B = A - 100
... AB = mx.stack([A, B])
... pinvs = mx.linalg.pinv(AB, stream=mx.cpu)
>>> pinvs
array([[[0.313084, -0.0467289, -0.107477],
        [0.364486, -0.158878, -0.0654206],
        [-0.0140187, 0.121495, 0.0794393]],
       [[0.0915965, -0.0186531, -0.0762813],
        [0.169056, -0.134106, -0.0378953],
        [-0.274593, 0.154526, 0.11614]]], dtype=float32)

>>> for M, M_plus in zip(AB, pinvs):
...     print(M @ M_plus @ M)
...     print(mx.allclose(M @ M_plus @ M, M, rtol=0, atol=1e-3))
array([[1, 2, 3],
       [6, -5, 4],
       [-9, 8, 7]], dtype=float32)
array(True, dtype=bool)
array([[-98.9998, -97.9998, -96.9998],
       [-93.9997, -105, -95.9998],
       [-109, -91.9998, -92.9998]], dtype=float32)
array(True, dtype=bool)

adi-dhulipala avatar Jul 13 '24 00:07 adi-dhulipala

Hey there @angeloskath, curious whether you or others on MLX could come back to this PR sometime? Thanks!

adi-dhulipala avatar Aug 01 '24 23:08 adi-dhulipala