[BUG] arithmetic operations with numpy arrays are not commutative
Describe the bug
When an mx.array is the first operand of an arithmetic operation with a NumPy array, the result is converted to complex64.
To Reproduce
import numpy as np
import mlx.core as mx
mx.array(1) + np.array(2)
x = mx.array(1)
y = np.array(2)
x + y
# >>> array(3+0j, dtype=complex64)
x * y
# >>> array(2+0j, dtype=complex64)
y + x
# >>> 3
type(y + x)
# >>> <class 'numpy.int64'>
type(x + y)
# >>> <class 'mlx.core.array'>
Expected behavior
No conversion to complex should happen. Arithmetic operations should be commutative, giving the same value and dtype in either operand order (ignoring whether the result is an np.ndarray or an mx.array).
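To make the expectation concrete, here is a minimal sketch of how the two operand orders could be compared side by side. `compare_operand_orders` is a hypothetical helper, not part of mlx or NumPy, and it assumes the operands are scalars or 0-d arrays so that `==` yields a single truth value.

```python
import operator


def compare_operand_orders(a, b, op=operator.add):
    """Apply `op` as op(a, b) and op(b, a) and summarize both results.

    Hypothetical helper for illustration only. Assumes the comparison
    `forward == backward` can be reduced to a single bool.
    """
    forward = op(a, b)   # dispatches to type(a) first, e.g. a.__add__(b)
    backward = op(b, a)  # dispatches to type(b) first, e.g. b.__add__(a)
    return {
        "forward_type": type(forward).__name__,
        "backward_type": type(backward).__name__,
        "same_type": type(forward) is type(backward),
        "same_value": bool(forward == backward),
    }


# Plain Python numbers behave commutatively under addition.
print(compare_operand_orders(1, 2.0))
```

Given the outputs in the snippet above, calling `compare_operand_orders(mx.array(1), np.array(2))` should report `same_type` as False, since one order returns an mlx array and the other a NumPy scalar.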
Desktop (please complete the following information):
- OS Version: macOS 14.1.1
- Version: 0.11
That looks like a bug 🤔
I am not sure if this is related, so I am adding it here. Let me know if it's not related and I will log another issue.
An np.ndarray of bfloat16 (created using ml_dtypes) is being interpreted as complex64 by mlx.
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> x = np.array(1., dtype=bfloat16)
>>> import mlx.core as mx
>>> mx.array(x)
array(1+0j, dtype=complex64)
>>> x = np.array(1)
>>> mx.array(x)
array(1, dtype=int64)
>>> x = np.array(1., dtype=bfloat16)
>>> x.dtype
dtype(bfloat16)
>>> type(x.dtype)
<class 'numpy.dtype[bfloat16]'>
>>>
The bfloat16 problem is a different issue. I will send a fix for the add shortly. Could you put the bfloat16 problem in a separate issue? It might be harder to fix and may require some changes in Nanobind (still investigating).