[BUG] `np.ndarray` of bfloat16 using ml_dtypes is being interpreted as complex64
Describe the bug
An `np.ndarray` whose dtype is bfloat16 (from ml_dtypes) is interpreted as complex64 by mlx.
To Reproduce
>>> from ml_dtypes import bfloat16
>>> import numpy as np
>>> x = np.array(1., dtype=bfloat16)
>>> import mlx.core as mx
>>> mx.array(x)
array(1+0j, dtype=complex64)
>>> x = np.array(1)
>>> mx.array(x)
array(1, dtype=int64)
>>> x = np.array(1., dtype=bfloat16)
>>> x.dtype
dtype(bfloat16)
>>> type(x.dtype)
<class 'numpy.dtype[bfloat16]'>
Expected behavior
The array should not be converted to complex64. It should remain bfloat16.
Desktop:
- OS Version: macOS 14.1.1
- Version: 0.11
Additional context
Originally posted by @lkarthee in https://github.com/ml-explore/mlx/issues/1066#issuecomment-2089573368
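Possible workaround
Until the conversion handles this dtype, one option may be to widen the NumPy array to float32 before passing it to mlx and cast back to bfloat16 afterwards. The widening is exact, since every bfloat16 value is representable in float32. The sketch below is not taken from mlx itself: `widen_bfloat16` and `to_mlx` are hypothetical helper names, and detecting the dtype through `x.dtype.name == "bfloat16"` is an assumption about how ml_dtypes names its dtype.

```python
import numpy as np


def widen_bfloat16(x: np.ndarray) -> np.ndarray:
    """Return a float32 copy if x has a dtype named 'bfloat16', else x unchanged.

    Assumption: the ml_dtypes bfloat16 dtype reports its name as 'bfloat16'.
    """
    if x.dtype.name == "bfloat16":
        # bfloat16 -> float32 is exact, so no precision is lost here.
        return x.astype(np.float32)
    return x


def to_mlx(x: np.ndarray):
    """Hypothetical helper: build an mlx array without the complex64 misread."""
    import mlx.core as mx  # imported lazily so the helper above works without mlx

    arr = mx.array(widen_bfloat16(x))
    # Cast back so the result has the dtype the caller started with.
    return arr.astype(mx.bfloat16) if x.dtype.name == "bfloat16" else arr
```

With ml_dtypes and mlx installed, `to_mlx(np.array(1., dtype=bfloat16))` should yield an mlx array of dtype bfloat16 rather than complex64. Arrays of other dtypes are passed to `mx.array` unchanged.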