mlx

[Not BUG] Scatter C++ does not work for 1D arrays

Open cemlyn007 opened this issue 1 year ago • 3 comments

Describe the bug

I may just be misunderstanding how to use scatter. I am trying to update a 1D array.

To Reproduce

```cpp
#include <stdexcept>

#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // Throws std::invalid_argument: v has 1 dimension, but scatter expects
  // i.ndim() + x.ndim() == 2.
  x = mlx::core::scatter(x, i, v, 0);
}
```

Throws:

```
libc++abi: terminating due to uncaught exception of type std::invalid_argument: [scatter] Updates with 1 dimensions does not match the sum of the array and indices dimensions 2.
```

Expected behaviour

I expected to get x == [True, False, True].

In Python this works:

```python
>>> import mlx.core
>>> x = mlx.core.array([False, False, False])
>>> i = mlx.core.array([0, 1])
>>> v = mlx.core.array([True, True])
>>> x.shape, i.shape, v.shape
((3,), (2,), (2,))
>>> x[i] = v
>>> x
array([True, True, False], dtype=bool)
```

Desktop:

- Mac ARM64
- Commit: f5f18b704fb0a77f6bd56dbaeb687464dcb24bd5

cemlyn007, Feb 28 '24 22:02

It seems like this works, but I am confused as to why the 1D case doesn't work.

```cpp
#include <stdexcept>

#include "mlx/mlx.h"

int main() {
  auto x = mlx::core::array({false, false, false}, mlx::core::bool_);
  auto i = mlx::core::array({0, 2}, mlx::core::int32);
  auto v = mlx::core::array({true, true}, mlx::core::bool_);
  if (x.ndim() != 1) {
    throw std::invalid_argument("x must be a 1D array");
  }
  if (i.ndim() != 1) {
    throw std::invalid_argument("i must be a 1D array");
  }
  if (v.ndim() != 1) {
    throw std::invalid_argument("v must be a 1D array");
  }
  // Reshape v from (2,) to (2, 1): one length-1 slice per index.
  v = mlx::core::expand_dims(v, 1);
  x = mlx::core::scatter(x, i, v, 0);  // x is now [true, false, true]
}
```
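To see why the `expand_dims` call fixes it, below is a rough plain C++ model (only `std::vector`, no MLX) of what `scatter(x, i, updates, 0)` does for a 1D `x`, as understood from this thread: `updates` is read as shape `(n, s)`, one slice per index. `scatter_1d_model` and `issue_example_matches` are hypothetical names, not MLX functions. The assumption that each index marks where its slice starts is labelled in the code; with `s == 1`, as in the issue, it reduces to `x[i[k]] = v[k]`.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical plain C++ model, NOT MLX code, of scatter along axis 0 of a
// 1D array. `updates` holds one row per index, i.e. shape (n, s).
// Assumption: indices[k] marks where row k's slice starts in x.
template <typename T>
std::vector<T> scatter_1d_model(std::vector<T> x,
                                const std::vector<int>& indices,
                                const std::vector<std::vector<T>>& updates) {
  // Leading dimension of updates must match the number of indices.
  if (updates.size() != indices.size()) {
    throw std::invalid_argument("updates needs one slice per index");
  }
  for (std::size_t k = 0; k < indices.size(); ++k) {
    // This model rejects negative indices instead of wrapping them.
    if (indices[k] < 0) {
      throw std::out_of_range("negative index");
    }
    const auto& slice = updates[k];  // row k of the (n, s) updates
    for (std::size_t j = 0; j < slice.size(); ++j) {
      std::size_t pos = static_cast<std::size_t>(indices[k]) + j;
      if (pos >= x.size()) {
        throw std::out_of_range("slice does not fit inside x");
      }
      x[pos] = slice[j];
    }
  }
  return x;
}

// The issue's data, with v = [true, true] reshaped to (2, 1).
bool issue_example_matches() {
  std::vector<bool> out =
      scatter_1d_model<bool>({false, false, false}, {0, 2}, {{true}, {true}});
  return out == std::vector<bool>{true, false, true};
}
```

`issue_example_matches()` returns `true`: with `v` laid out as `(2, 1)`, the model writes `true` at positions 0 and 2, giving `[true, false, true]`, the result the issue expected. A flat `(2,)` update has no slice dimension, which is what the rank check in MLX rejects.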

cemlyn007, Feb 28 '24 22:02

This isn't a bug. The C++ API is a little hard to use and undocumented, so sorry you ran into that issue.

For the C++ scatter API, the following must be true: v.ndim() == i.ndim() + x.ndim(). This is how scatter knows which part of the update corresponds to each index.
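A plain C++ model of that rule (no MLX dependency) may help when working out shapes by hand. `check_scatter_shapes` and `python_style_updates_shape` are hypothetical helpers, not part of MLX. The total-rank check mirrors the rule above and the error message in the issue; the extra checks (leading dimensions equal the index shape, trailing dimensions fit inside `x`) are assumptions about how MLX reads the layout, not taken from its source.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Shape = std::vector<int>;

// Hypothetical model, NOT MLX code. Returns "" if `updates` looks valid for
// scatter(x, idx, updates, axis), otherwise a short reason.
std::string check_scatter_shapes(const Shape& x, const Shape& idx,
                                 const Shape& updates) {
  // The stated rule: updates.ndim == idx.ndim + x.ndim.
  if (updates.size() != idx.size() + x.size()) {
    return "updates has " + std::to_string(updates.size()) +
           " dims, expected " + std::to_string(idx.size() + x.size());
  }
  // Assumption: the leading dims enumerate the indices.
  for (std::size_t d = 0; d < idx.size(); ++d) {
    if (updates[d] != idx[d]) return "leading dims must equal the index shape";
  }
  // Assumption: the trailing dims are the slice written per index and
  // cannot be larger than x along any axis.
  for (std::size_t d = 0; d < x.size(); ++d) {
    if (updates[idx.size() + d] > x[d]) return "slice is larger than x";
  }
  return "";
}

// Hypothetical helper: the updates shape that mimics Python's x[i] = v along
// `axis`, i.e. the index shape followed by x's shape with `axis` shrunk to 1.
Shape python_style_updates_shape(const Shape& x, const Shape& idx, int axis) {
  Shape out = idx;
  for (std::size_t d = 0; d < x.size(); ++d) {
    out.push_back(static_cast<int>(d) == axis ? 1 : x[d]);
  }
  return out;
}
```

For the issue's case, `python_style_updates_shape({3}, {2}, 0)` gives `{2, 1}`, which is the shape `expand_dims(v, 1)` produces, while the original `{2}` fails the rank check. For a 2D `x` of shape `(3, 4)` with the same indices it gives `{2, 1, 4}`, so a `(2, 4)` block of rows would need the same `expand_dims(v, 1)`.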

I will leave this open and mark it as documentation. I don't think we will change the C++ scatter op. We may add an operator[] that behaves more like Python index updates.

awni, Feb 29 '24 00:02

Ah, that's cool. Feel free to close this; hopefully anyone else who stumbles into this will find this issue as a piece of documentation 💯 Thank you very much! I'm slowly dipping my toes in, trying to build a little Snake environment using MLX.

Hope you have a wonderful day!

cemlyn007, Feb 29 '24 00:02