[BUG] mlx crashes with msg - uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64
Describe the bug
With the mlx backend, keras.ops.scatter on int64 data crashes the Python process with an uncaught C++ exception ([Scatter::eval_gpu] Does not support int64) instead of raising a catchable Python error.
To Reproduce
import numpy as np
import mlx.core as mx
from keras.src import backend
from keras.src.ops import core

backend.backend()
# >>> 'mlx'

indices = np.array([[1], [3], [4], [7]])
values = np.array([9, 10, 11, 12])  # NumPy's default integer dtype on macOS is int64

x = core.scatter(indices, values, (8,))
x  # printing forces evaluation of the lazy mlx array, which triggers the crash
# libc++abi: terminating due to uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64
# zsh: abort      python
# keras.ops.scatter for the mlx backend
def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    zeros = mx.zeros(shape, dtype=values.dtype)
    indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    # the scatter-add below is what hits [Scatter::eval_gpu] when values.dtype is int64
    zeros = zeros.at[indices].add(values)
    return zeros
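For reference, the crash reproduces in pure MLX without Keras. A minimal sketch (assuming the default device is the GPU, as it is on Apple silicon):

import mlx.core as mx

zeros = mx.zeros((8,), dtype=mx.int64)   # int64 output array is what the GPU kernel rejects
idx = mx.array([1, 3, 4, 7])
vals = mx.array([9, 10, 11, 12], dtype=mx.int64)
out = zeros.at[idx].add(vals)            # lazily records the Scatter op
mx.eval(out)                             # aborts here with [Scatter::eval_gpu] Does not support int64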
Expected behavior
MLX should not crash; it should raise a catchable exception or error.
Desktop:
- OS Version: macOS 14.1
- MLX Version: 0.12.2
If you are just asking for a catchable exception, then #1077 should close this. We would like to eventually allow int64 and other 8-byte types to work with scatter, but that is more involved.
Thank you, Awni. Some observations:
- The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning a workaround in the error message?
- Can scatter ops use the CPU device for int64 and uint64?
- A note in mlx.core.array.at about supported dtypes and devices (CPU, GPU) would be helpful.
- Are there any other ops that are not supported on the GPU and run on the CPU?
zeros = mx.zeros(shape, dtype=values.dtype)
zeros = zeros.at[indices].add(values)

I tried the following, and it does not work because add does not take a device kwarg:

if zeros.dtype in (mx.int64, mx.uint64) and mx.default_device().type == mx.DeviceType.gpu:
    device = mx.Device(type=mx.DeviceType.cpu)
    zeros = zeros.at[indices].add(values, device=device)  # fails: add() has no device kwarg
else:
    zeros = zeros.at[indices].add(values)
It would be helpful if MLX could fall back to the CPU for scatter ops that are not supported on the GPU, or allow a device/stream kwarg for all scatter ops; a sketch of such a fallback follows.
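A minimal sketch of what that fallback could look like in user code today, using the mx.stream context manager (scatter_add_with_fallback is a hypothetical helper, not an MLX API):

import mlx.core as mx

def scatter_add_with_fallback(zeros, indices, values):
    # Hypothetical helper: route 8-byte integer dtypes to the CPU stream,
    # since the Metal scatter kernel does not support them.
    if zeros.dtype in (mx.int64, mx.uint64):
        with mx.stream(mx.cpu):
            return zeros.at[indices].add(values)
    return zeros.at[indices].add(values)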
Additional ops that are impacted by this bug (see the sketch after this list for a workaround):
- mx.cumsum
- mx.cumprod
- mx.diag
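A sketch of the same CPU-stream workaround applied to these ops (assuming int64 inputs trigger the same GPU limitation there):

import mlx.core as mx

x = mx.array([1, 2, 3, 4], dtype=mx.int64)
with mx.stream(mx.cpu):   # run on the CPU, where int64 is supported
    s = mx.cumsum(x)
    p = mx.cumprod(x)
    d = mx.diag(x)
mx.eval(s, p, d)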
The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning a workaround in the error message?
I improved the message in #1077. The problem is with the values.
Can scatter ops use the CPU device for int64 and uint64?
We prefer not to silently route to the CPU for ops without a GPU backend. You can do this in the API by switching the default stream to the CPU before calling the scatter when the dtype is int64/uint64.
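For example (a sketch; swapping the default device globally affects other work, so it is restored afterwards):

import mlx.core as mx

zeros = mx.zeros((8,), dtype=mx.int64)
indices = mx.array([1, 3, 4, 7])
values = mx.array([9, 10, 11, 12], dtype=mx.int64)

prev = mx.default_device()
mx.set_default_device(mx.cpu)   # the scatter now runs on the CPU
out = zeros.at[indices].add(values)
mx.eval(out)
mx.set_default_device(prev)     # restore the previous default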
Are there any other ops that are not supported on the GPU and run on the CPU?
Just a few: FFT and some of the LAPACK ops (QR / inverse). Metal support for FFT is coming soon in #981.
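For those, passing an explicit CPU stream works; for example, mirroring the QR example from the mlx.core.linalg docs:

import mlx.core as mx

A = mx.array([[2.0, 1.0], [1.0, 2.0]])
Q, R = mx.linalg.qr(A, stream=mx.cpu)  # LAPACK-backed, so it runs on the CPU
mx.eval(Q, R)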
I tried this and it does not work, as add does not take a device kwarg:
You can use a context manager. For most free ops the stream kwarg also works. E.g.:
v = mx.array([1, 2, 3])
u = mx.array([1, 2])
idx = mx.array([0, 1])
with mx.stream(mx.cpu):
    out = v.at[idx].add(u)
Thank you @awni for the fix.