
[BUG] arithmetic operations with numpy arrays are not commutative

Open lkarthee opened this issue 9 months ago • 3 comments

Describe the bug: When an mx.array is the first operand of an arithmetic op with a NumPy array, the result is converted to complex numbers.

To Reproduce


import numpy as np
import mlx.core as mx

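mx.array(1) + np.array(2)  # same computation as x + y below
x = mx.array(1)
y = np.array(2)
x + y  # mx.array on the left: result comes back as complex
# >>> array(3+0j, dtype=complex64)
x * y
# >>> array(2+0j, dtype=complex64)
y + x  # np.ndarray on the left: integer result of NumPy type
# >>> 3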
# >>> 3
type(y + x)
# >>> <class 'numpy.int64'>
type(x + y)
# >>> <class 'mlx.core.array'>
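Why the result types differ: Python evaluates `a + b` by calling `type(a).__add__` first and falls back to `type(b).__radd__` only if that returns `NotImplemented` (or if `b`'s type is a subclass of `a`'s). So `x + y` above is handled by MLX and `y + x` by NumPy, which means two separate conversion paths, and that is how the dtype can differ between the two orders. The sketch below models only this dispatch rule with three hypothetical pure-Python classes; it does not reproduce how MLX or NumPy actually convert their operands.

```python
# Hypothetical stand-in classes: they model Python's operator dispatch only,
# not how MLX or NumPy actually convert their operands.
class LeftLib:
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        # Accepts any operand with a .v and returns its own type
        return LeftLib(self.v + other.v)


class RightLib:
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        return RightLib(self.v + other.v)

    def __radd__(self, other):
        # Only reached when the left operand's __add__ returns NotImplemented
        return RightLib(other.v + self.v)


class Deferring:
    def __init__(self, v):
        self.v = v

    def __add__(self, other):
        return NotImplemented  # ask Python to try other.__radd__


x = LeftLib(1)
y = RightLib(2)
print(type(x + y).__name__)             # LeftLib: x.__add__ runs first
print(type(y + x).__name__)             # RightLib: y.__add__ runs first
print(type(Deferring(1) + y).__name__)  # RightLib: falls back to y.__radd__
```

Whichever operand's method handles the operation also decides how the other operand is converted, so a conversion bug on one side shows up in only one operand order.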

Expected behavior: No conversion to complex should happen. Arithmetic ops should be commutative (ignoring the container type of the result, i.e. np.ndarray vs. mx.array).
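A small regression check for this expectation could compare both operand orders by value and also flag a complex result that neither input asked for; value equality alone would not catch this bug, since 3+0j == 3. `check_commutative` is a hypothetical helper, and it assumes both results can be converted with `np.asarray`, which holds for NumPy values and is assumed here for MLX arrays.

```python
import operator

import numpy as np


def check_commutative(a, b, op=operator.add):
    """Hypothetical check: op(a, b) and op(b, a) agree and stay non-complex.

    Assumes np.asarray can convert both inputs and both results
    (true for NumPy values; assumed for MLX arrays).
    """
    left = np.asarray(op(a, b))
    right = np.asarray(op(b, a))
    inputs_complex = np.iscomplexobj(np.asarray(a)) or np.iscomplexobj(np.asarray(b))
    # The reported symptom: real inputs silently producing a complex64 result
    if not inputs_complex and (np.iscomplexobj(left) or np.iscomplexobj(right)):
        return False
    return bool(np.array_equal(left, right))


print(check_commutative(np.array(1), np.array(2)))
print(check_commutative(np.array(1), np.array(2), operator.mul))
```

With MLX installed, the case from this report would be `check_commutative(mx.array(1), np.array(2))`, which on the affected version should return False because `x + y` comes back as complex64.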

Desktop:

  • OS Version: macOS 14.1.1
  • Version: 0.11


lkarthee • May 02 '24 03:05

That looks like a bug 🤔

awni • May 02 '24 03:05

I'm not sure whether this is related, so I'm adding it here. Let me know if it's not, and I will log another issue.

An np.ndarray of bfloat16 (created with ml_dtypes) is being interpreted as complex64 by mlx.

>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> x = np.array(1., dtype=bfloat16)
>>> import mlx.core as mx
>>> mx.array(x)
array(1+0j, dtype=complex64)
>>> x = np.array(1)
>>> mx.array(x)
array(1, dtype=int64)
>>> x = np.array(1., dtype=bfloat16)
>>> x.dtype
dtype(bfloat16)
>>> type(x.dtype)
<class 'numpy.dtype[bfloat16]'>
>>>
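Until this is fixed, one possible workaround, assuming a float32 round trip is acceptable, is to upcast bfloat16 NumPy arrays before handing them to mlx and cast back on the MLX side. `upcast_unsupported` is a hypothetical helper, not part of MLX or ml_dtypes; it assumes the ml_dtypes bfloat16 dtype reports `dtype.name == "bfloat16"` (as the `dtype(bfloat16)` repr above suggests) and supports `astype(np.float32)`.

```python
import numpy as np


def upcast_unsupported(x: np.ndarray, unsupported=("bfloat16",)) -> np.ndarray:
    """Hypothetical workaround: cast x to float32 if its dtype name is listed.

    Arrays with any other dtype are returned unchanged (same object).
    """
    if x.dtype.name in unsupported:
        return x.astype(np.float32)
    return x
```

Usage would then be `mx.array(upcast_unsupported(x)).astype(mx.bfloat16)`, which avoids passing the bfloat16 ndarray to `mx.array` directly.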

lkarthee • May 02 '24 04:05

The bfloat16 thing is a different issue. I will send a fix for the add shortly. Could you put the bfloat16 problem in a separate issue? It might be harder to fix and may require some changes in Nanobind (still investigating).

awni • May 02 '24 14:05