[Not BUG] Scatter C++ does not work for 1D arrays
Describe the bug
I may simply be misunderstanding how to use scatter. I am trying to update a 1D array.
To Reproduce
```cpp
#include <stdexcept>
#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  x = mlx::core::scatter(x, i, v, 0);
}
```
Throws:
```
libc++abi: terminating due to uncaught exception of type std::invalid_argument: [scatter] Updates with 1 dimensions does not match the sum of the array and indices dimensions 2.
```
Expected behaviour
I expected to get x == [True, False, True].
In Python this works:
```python
>>> x = mlx.core.array([False, False, False])
>>> i = mlx.core.array([0, 1])
>>> v = mlx.core.array([True, True])
>>> x.shape, i.shape, v.shape
((3,), (2,), (2,))
>>> x[i] = v
>>> x
array([True, True, False], dtype=bool)
```
Desktop:
- Mac ARM64
- f5f18b704fb0a77f6bd56dbaeb687464dcb24bd5
This version seems to work, but I am confused about why the 1D case above doesn't:
```cpp
#include <stdexcept>
#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // Reshape v from (2,) to (2, 1) so that v.ndim() == i.ndim() + x.ndim()
  v = mlx::core::expand_dims(v, 1);
  x = mlx::core::scatter(x, i, v, 0);
}
```
This isn't a bug. The C++ API is a little hard to use and undocumented, so sorry you ran into that issue.
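For the C++ scatter API, the following must hold: v.ndim() == i.ndim() + x.ndim(). This is so we know which part of the update corresponds to each index: the leading i.ndim() dimensions of v line up with the indices, and the trailing x.ndim() dimensions hold the slice of x written at each index. For a 1D x, each slice has shape (1,), which is why v needs shape (2, 1) in your example.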
I will leave this open and mark it as documentation. I don't think we will change the C++ scatter op. We may add an operator[] which will behave more like Python index updates.
Ah, that's cool. Feel free to close this; hopefully others who stumble into this will find this issue as a piece of documentation 💯 Thank you very much! I'm slowly dipping my toes in, trying to build a little Snake environment using MLX.
Hope you have a wonderful day!