
[BUG] mlx crashes with msg - uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64


Describe the bug: calling keras' scatter op on the mlx backend aborts the whole Python process with an uncaught C++ exception instead of raising a catchable error.

To Reproduce

Code snippet:

import numpy as np
import mlx.core as mx
from keras.src import backend
from keras.src.ops import core

backend.backend()
# >>> 'mlx'
indices = np.array([[1], [3], [4], [7]])  # defaults to int64
values = np.array([9, 10, 11, 12])        # defaults to int64
x = core.scatter(indices, values, (8,))
x  # printing forces evaluation; the process aborts:
# libc++abi: terminating due to uncaught exception of type std::invalid_argument: [Scatter::eval_gpu] Does not support int64
# zsh: abort      python
# keras.ops.scatter for the mlx backend
# (convert_to_tensor is keras' mlx-backend helper that returns an mx.array)
def scatter(indices, values, shape):
    indices = convert_to_tensor(indices)
    values = convert_to_tensor(values)
    # output dtype follows values, so int64 inputs give an int64 scatter
    zeros = mx.zeros(shape, dtype=values.dtype)
    indices = tuple(indices[..., i] for i in range(indices.shape[-1]))
    zeros = zeros.at[indices].add(values)
    return zeros
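
For what it's worth, the abort reproduces without keras. A minimal sketch, assuming the default device is the GPU (np.array of Python ints defaults to int64, which scatter's GPU kernel rejects):

import mlx.core as mx

zeros = mx.zeros((8,), dtype=mx.int64)
idx = mx.array([1, 3, 4, 7])
vals = mx.array([9, 10, 11, 12], dtype=mx.int64)
out = zeros.at[idx].add(vals)
mx.eval(out)  # aborts: [Scatter::eval_gpu] Does not support int64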

Expected behavior: MLX should not crash; it should raise a catchable exception or error instead.

Desktop:

  • OS Version: macOS 14.1
  • MLX Version: 0.12.2


lkarthee avatar May 04 '24 06:05 lkarthee

If you are just asking for a catchable exception, then #1077 should close this. We would like to eventually allow int64 and other 8-byte types to work with scatter, but that is more involved.
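
Once that lands, the failure should be guardable on the Python side. A sketch, assuming the C++ std::invalid_argument surfaces as a Python ValueError:

import mlx.core as mx

zeros = mx.zeros((8,), dtype=mx.int64)
try:
    out = zeros.at[mx.array([1, 3])].add(mx.array([9, 10], dtype=mx.int64))
    mx.eval(out)
except ValueError as e:
    # with the fix this is raised instead of aborting the whole process
    print(f"scatter failed: {e}")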

awni avatar May 05 '24 02:05 awni

Thank you Awni. Some observations:

  • The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning the workaround in the error message?
  • Can the scatter ops use the CPU device for int64 and uint64?
  • Adding a note about the supported dtypes and devices (cpu, gpu) to mlx.core.array.at would be helpful.
  • Are there any other ops which are not supported on the GPU and run on the CPU?
# the relevant lines from keras' scatter:
zeros = mx.zeros(shape, dtype=values.dtype)
zeros = zeros.at[indices].add(values)

I tried modifying this so the scatter runs on the CPU, but it does not work since add does not take a device kwarg:

if zeros.dtype in [mx.int64, mx.uint64] and mx.default_device().type == mx.DeviceType.gpu:
    device = mx.Device(type=mx.DeviceType.cpu)
    zeros = zeros.at[indices].add(values, device=device)  # TypeError: no device kwarg
else:
    zeros = zeros.at[indices].add(values)

It would be helpful if mlx could fall back to the CPU for scatter ops which are not supported on the GPU, or allow a device kwarg for all scatter ops.

Additional ops that are impacted by this bug (a quick check follows the list):

  • mx.cumsum
  • mx.cumprod
  • mx.diag
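
A quick check (assuming the default GPU device; only the first call is reached because the process aborts):

import mlx.core as mx

a = mx.array([1, 2, 3], dtype=mx.int64)
mx.eval(mx.cumsum(a))   # aborts on the GPU with int64, like scatter
mx.eval(mx.cumprod(a))  # never reached; same failure per the list above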

lkarthee avatar May 05 '24 03:05 lkarthee

> The crash message is confusing: it does not say whether the problem is with the array, the indices, or the values. Can we improve it by mentioning the workaround in the error message?

I improved the message in #1077. The problem is with the values.

> Can the scatter ops use the CPU device for int64 and uint64?

We prefer not to silently route to the CPU for ops without a GPU back-end. You can do this in the API by changing the default stream to the CPU before calling the scatter when the dtype is int64/uint64.
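
For example, a sketch using mx.set_default_device (which also flips the default stream to the CPU's default stream):

import mlx.core as mx

mx.set_default_device(mx.cpu)  # int64/uint64 scatters created from here run on the CPU
# ... build the scatter ...
mx.set_default_device(mx.gpu)  # restore the GPU default afterwards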

> Are there any other ops which are not supported on the GPU and run on the CPU?

Just a few: FFT and some of the LAPACK ops (QR / inverse). Metal support for FFT is coming soon in #981.

> I tried modifying this so the scatter runs on the CPU, but it does not work since add does not take a device kwarg:

You can use a context manager. For most free functions the stream kwarg also works. E.g.

import mlx.core as mx

v = mx.array([1, 2, 3])
u = mx.array([1, 2])
idx = mx.array([0, 1])

# ops created inside the context run on the CPU stream
with mx.stream(mx.cpu):
    out = v.at[idx].add(u)
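
Applied to the scatter from the top of the thread, a conditional fallback could look like this (a sketch, not the keras implementation; the helper name is made up):

import mlx.core as mx

def scatter_with_cpu_fallback(indices, values, shape):
    zeros = mx.zeros(shape, dtype=values.dtype)
    idx = tuple(indices[..., i] for i in range(indices.shape[-1]))
    if values.dtype in (mx.int64, mx.uint64):
        # 64-bit scatter has no GPU kernel yet, so run it on the CPU stream
        with mx.stream(mx.cpu):
            return zeros.at[idx].add(values)
    return zeros.at[idx].add(values)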

awni avatar May 14 '24 00:05 awni

Thank you @awni for the fix.

lkarthee avatar May 14 '24 01:05 lkarthee