
Add Gemm for complex64

Open colivarese opened this issue 2 years ago • 1 comments

Not sure if this is an issue on my end, but I'm unable to use mx.matmul or the @ operator on two mx.arrays. It throws the following error:

```
libc++abi: terminating due to uncaught exception of type std::runtime_error: [metal::Device] Unable to load kernel gemv_t_complex64_bm8_bn8_tm4_tn1
zsh: abort
```

This is my whole code:

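```python
import mlx.core as mx
import numpy as np

# make_hippo is not shown in the original post.
def make_nplr_hippo(N: int):
    nhippo = -1 * make_hippo(N)
    p = mx.array(0.5 * np.sqrt(2 * np.arange(1, N + 1) + 1))
    q = 2 * p
    p = mx.transpose(mx.expand_dims(p, axis=0))  # column vector, shape (N, 1)
    q = mx.expand_dims(q, axis=0)  # row vector, shape (1, N)
    S = nhippo + p * q  # posted as `p_ * q_`; assumed to mean the reshaped p and q
    lambda_, V = np.linalg.eig(S)
    return lambda_, p, q, V

n = 3
lambda_, p, q, V = make_nplr_hippo(n)  # posted as `_make_nplr_hippo`
Vc = mx.transpose(mx.array(np.array(V, dtype=complex)))
p = mx.matmul(mx.array(Vc), p)
print(p)
```

I'm using a Mac M1 with VS Code and Python 3.11.3.

| Package | Version |
| --- | --- |
| mlx | 0.0.4 |
| numpy | 1.26.2 |
| pip | 23.3.1 |
| pybind11 | 2.11.1 |
| pybind11-global | 2.11.1 |
| setuptools | 67.6.1 |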

colivarese, Dec 12 '23 05:12

You're not doing anything wrong. We don't yet have a complex64 Metal GEMM implemented (or a complex64 CPU GEMM, for that matter).

@jagrit06 any thoughts on this? How hard would it be to get the complex type working?

We could easily add cblas_cgemm for the CPU backend.

For the Metal backend, another option might be MPS. It has a complex type, so maybe it works with their GEMM kernel.
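Until a complex64 kernel exists, one possible stopgap (not proposed in this thread, just a sketch) is to split the complex product into real-valued matmuls using (A + iB)(C + iD) = (AC − BD) + i(AD + BC). The helper name `complex_matmul_via_real` and its `matmul` parameter are hypothetical, and the demo runs on NumPy only. Passing `mx.matmul` with float32 `mx.array` inputs should follow the same pattern, but that is an untested assumption.

```python
import numpy as np


def complex_matmul_via_real(a_re, a_im, b_re, b_im, matmul=np.matmul):
    """Return the (real, imag) parts of (a_re + i*a_im) @ (b_re + i*b_im).

    Only real-valued matmuls are issued, so no complex GEMM/GEMV kernel is
    required. `matmul` is a hypothetical hook for swapping in another
    backend's real-valued matmul.
    """
    c_re = matmul(a_re, b_re) - matmul(a_im, b_im)
    c_im = matmul(a_re, b_im) + matmul(a_im, b_re)
    return c_re, c_im


# NumPy demo: a complex 3x3 matrix times a real 3x1 column vector,
# mirroring the shapes in the report above.
rng = np.random.default_rng(0)
V = (rng.standard_normal((3, 3)) + 1j * rng.standard_normal((3, 3))).astype(np.complex64)
p = rng.standard_normal((3, 1)).astype(np.float32)

re, im = complex_matmul_via_real(V.real, V.imag, p, np.zeros_like(p))
result = re + 1j * im  # recombine into a complex array for comparison
```

This costs four real matmuls instead of one complex one, so it is a workaround rather than a replacement for a native kernel.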

awni, Dec 12 '23 05:12