
[Feature] Metal inverse (`mx.linalg.inv`)

Open petertsoi opened this issue 1 year ago • 9 comments

Describe the bug

When trying to invert a small 3x3 matrix (camera intrinsics), mlx crashes.

To Reproduce

import mlx.core as mx
intrinsics = mx.array([[1162.38, 0, 618.25], [0, 1156.83, 410.75], [0, 0, 1]])
mx.linalg.inv(intrinsics)

libc++abi: terminating due to uncaught exception of type std::runtime_error: [Inverse::eval_gpu] Metal inversion NYI.
zsh: abort      python

Expected behavior

Works properly in numpy

Desktop (please complete the following information):

  • macOS 14.4.1
  • mlx: 0.15.1
  • python: 3.11.9

petertsoi avatar Jun 27 '24 00:06 petertsoi

Also not working in mlx-swift, which is where I'm using it from but reproduced in python so filed it here. The CPU backend appears to work though.

petertsoi avatar Jun 27 '24 00:06 petertsoi

Yes, this isn't a bug, the GPU back-end is not yet implemented. It's most likely going to take some time before we have GPU support for matrix inversion. I changed this to be a feature req rather than a bug, and we can leave the issue open.

awni avatar Jun 27 '24 00:06 awni

My recommendation is to use the CPU for now. You can do something like:

out = mx.linalg.inv(x, stream=mx.cpu)

Just for that operation.

awni avatar Jun 27 '24 00:06 awni

By the way, if all you want to do is 3x3 matrix inversion, it is way faster to write it out explicitly and compile it with mlx. The inversion would be as simple as the following:

from functools import partial

import mlx.core as mx


@partial(mx.compile, shapeless=True)
def _inverse_3x3(a11, a12, a13, a21, a22, a23, a31, a32, a33):
    det = (
        a11 * a22 * a33
        + a12 * a23 * a31
        + a13 * a21 * a32
        - a11 * a23 * a32
        - a12 * a21 * a33
        - a13 * a22 * a31
    )
    c11 = (a22 * a33 - a23 * a32) / det
    c12 = (a13 * a32 - a12 * a33) / det
    c13 = (a12 * a23 - a13 * a22) / det
    c21 = (a23 * a31 - a21 * a33) / det
    c22 = (a11 * a33 - a13 * a31) / det
    c23 = (a13 * a21 - a11 * a23) / det
    c31 = (a21 * a32 - a22 * a31) / det
    c32 = (a12 * a31 - a11 * a32) / det
    c33 = (a11 * a22 - a12 * a21) / det
    return c11, c12, c13, c21, c22, c23, c31, c32, c33


def inverse_3x3(A):
    shape = A.shape
    return mx.concatenate(
        _inverse_3x3(*mx.split(A.reshape(*shape[:-2], -1), 9, -1)), -1
    ).reshape(shape)

For inverting thousands of 3x3 matrices the improvement over CPU is pretty great on my puny M2 Air:

Batch | linalg.inv | inverse_3x3
------+------------+------------
1     |       0.04 |       1.0
16    |       0.1  |       1.0
256   |       1.8  |       1.0
1024  |       7.5  |       1.1
8192  |      59.1  |       1.6
32768 |     243.3  |       3.8

For a single matrix, using the GPU is obviously overkill. But if you want to do 3x3 matmuls, for instance, writing them out explicitly like I did above may be significantly faster. The same goes for triangle intersection math etc.
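
The closed-form formula used in `_inverse_3x3` above can be sanity-checked without MLX at all. The sketch below applies the same determinant and cofactor expressions to plain Python floats, using the intrinsics matrix from the original report, and multiplies the result back against the input. The helper names `inverse_3x3_floats` and `matmul_3x3` are made up for this illustration and are not part of mlx.

```python
def inverse_3x3_floats(m):
    """Invert a 3x3 matrix (nested lists) with the explicit cofactor formula."""
    (a11, a12, a13), (a21, a22, a23), (a31, a32, a33) = m
    det = (
        a11 * a22 * a33 + a12 * a23 * a31 + a13 * a21 * a32
        - a11 * a23 * a32 - a12 * a21 * a33 - a13 * a22 * a31
    )
    if det == 0:
        raise ValueError("matrix is singular")
    # Entries of the adjugate (transposed cofactor matrix), divided by det.
    return [
        [(a22 * a33 - a23 * a32) / det, (a13 * a32 - a12 * a33) / det, (a12 * a23 - a13 * a22) / det],
        [(a23 * a31 - a21 * a33) / det, (a11 * a33 - a13 * a31) / det, (a13 * a21 - a11 * a23) / det],
        [(a21 * a32 - a22 * a31) / det, (a12 * a31 - a11 * a32) / det, (a11 * a22 - a12 * a21) / det],
    ]


def matmul_3x3(a, b):
    """Plain 3x3 matrix product, used only to verify the inverse."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


K = [[1162.38, 0.0, 618.25], [0.0, 1156.83, 410.75], [0.0, 0.0, 1.0]]
K_inv = inverse_3x3_floats(K)
product = matmul_3x3(K_inv, K)  # should be close to the identity
```

For this upper-triangular matrix the top-left entry of the inverse reduces to 1 / 1162.38, which makes for an easy spot check.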

angeloskath avatar Jun 27 '24 18:06 angeloskath

Hi @awni, any update on when matrix inversion would be available on GPU? It would be extremely helpful for a lot of applications. For example, using Gauss-Newton second order optimizers (which are way more efficient than first order GD or ADAM) requires matrix inversion of the Jacobian matrix.

Thanks a lot.

kyrollosyanny avatar Feb 13 '25 01:02 kyrollosyanny

No update sorry. It's available on the CPU for now, use e.g. stream=mx.cpu

awni avatar Feb 13 '25 01:02 awni

Interestingly, it doesn't give me any errors if I do this mx.linalg.inv(JtJ,stream=mx.gpu) but maybe it is still not using the gpu?

kyrollosyanny avatar Feb 13 '25 20:02 kyrollosyanny

Maybe you are using an old version of MLX:

>>> mx.linalg.inv(mx.ones((2, 2)))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: [linalg::inv] This op is not yet supported on the GPU. Explicitly pass a CPU stream to run it.
>>> 

awni avatar Feb 13 '25 20:02 awni

Yes, just updated and gave me the error you mentioned. Only works with cpu stream for now. Thanks for the help.

kyrollosyanny avatar Feb 13 '25 20:02 kyrollosyanny