
array.shape should return tuple not list for maximal compatibility with numpy

Open dastrobu opened this issue 1 year ago • 5 comments

I like the statement:

MLX has a Python API that closely follows NumPy.

One thing I noticed is that mx.array.shape returns a list:

>>> import mlx.core as mx
>>> mx.ones([2, 3]).shape
[2, 3]

while numpy returns a tuple:

>>> import numpy as np
>>> np.ones([2, 3]).shape
(2, 3)

Given that tuples are immutable containers in python, returning a tuple makes more sense from my point of view.

I am not sure whether this was a design decision. But given that mlx is still in an early phase, it might be worth considering changing the shape to a tuple type for maximal consistency with numpy.

dastrobu, Dec 27 '23 09:12

To be honest, this wasn't a very intentional design decision. Pybind11 by default translates std::vector to a Python list, so that's why it is the way it is. I believe it's easy to change, but it is likely a breaking API change for downstream code.

I see some arguments in favor of one or the other. But I think it would be good to have some discussion around the pros and cons of each. So please chime in with your thoughts. Here's some of mine:

  • List is nice because it can be used almost everywhere a tuple can be (with the exception of keys in maps) and it is easy to modify. A common idiom is to get the shape from an array, modify it, and use it to do another operation, like a reshape.

  • Tuple is nice for perfect compatibility with NumPy. Beyond that, what advantages does it have?
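
The idiom from the first point can be written with either container. A minimal sketch in plain Python, using literals as stand-ins for an array's shape (no MLX or NumPy call is assumed here):

```python
# Stand-ins for what a `.shape` attribute might return.
list_shape = [2, 3, 4]
tuple_shape = (2, 3, 4)

# With a list, the shape can be modified in place before a reshape.
merged = list(list_shape)  # copy so the original is left alone
merged[-2:] = [merged[-2] * merged[-1]]  # collapse the last two axes

# With a tuple, the same result is built by concatenation instead.
merged_from_tuple = tuple_shape[:-2] + (tuple_shape[-2] * tuple_shape[-1],)

# Tuples are hashable, so they can serve as dictionary keys; lists cannot.
cache = {tuple_shape: "ok"}
try:
    cache[list_shape] = "fails"
    list_is_hashable = True
except TypeError:
    list_is_hashable = False
```

The tuple version costs one extra expression, which is roughly the ergonomic price being weighed against NumPy compatibility.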

awni, Dec 27 '23 14:12

Thanks for your comment.

I think it does not really matter as you can convert lists and tuples very easily and performance shouldn't really matter when dealing with shapes.

Having compatibility with numpy is the biggest advantage from my point of view, as you can easily migrate existing code from numpy arrays to mx arrays.
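
One concrete place where the difference shows up when porting code is shape comparison, since Python never treats a list and a tuple as equal. A small sketch with literals standing in for the two return types; `same_shape` is a hypothetical helper, not part of either library:

```python
numpy_style = (2, 3)  # the type np.ones([2, 3]).shape returns
list_style = [2, 3]   # the type a list-returning shape gives

# A check written for NumPy silently evaluates to False against a list.
direct_match = list_style == (2, 3)


def same_shape(a, b):
    """Compare two shapes regardless of whether they are lists or tuples."""
    return tuple(a) == tuple(b)


normalized_match = same_shape(list_style, numpy_style)
```

Code that already wraps shapes in `tuple(...)` keeps working whichever type the library returns.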

I can live with lists, I just thought it might have been overlooked as numpy compatibility seemed a design goal, which I really like, as this makes learning MLX easier.

Implementation wise this could be as easy as py::tuple(py::cast(a.shape())):

      // TODO, this makes a deep copy of the shape
      // implement alternatives to use reference
      // https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
      .def_property_readonly(
          "shape",
          [](const array& a) { return py::tuple(py::cast(a.shape())); },

Regarding the implementation details, I'm uncertain if the use of py::cast addresses the outstanding TODO in the code. Given my limited familiarity with pybind11, the term 'cast' suggests a non-copy operation to me. See also https://pybind11.readthedocs.io/en/stable/advanced/pycpp/object.html#casting-back-and-forth.

When there is agreement on the Python type to use, I'd suggest exposing strides in addition to shape to Python, as numpy does.
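
For context on what exposing strides would mean: in NumPy, `strides` is a tuple giving the number of bytes to step along each axis. The sketch below computes that tuple for a row-major (C-contiguous) layout in plain Python. `c_contiguous_strides` is a hypothetical helper for illustration only, and whether MLX would report bytes or elements is not settled in this thread.

```python
def c_contiguous_strides(shape, itemsize):
    """Byte strides for a row-major array with the given shape."""
    strides = []
    step = itemsize
    # Walk from the last axis to the first; the last axis is contiguous.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


# A 2x3 array of 4-byte elements, and a 0-d array.
strides_2x3_float32 = c_contiguous_strides((2, 3), itemsize=4)
strides_0d = c_contiguous_strides((), itemsize=4)
```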

dastrobu, Dec 27 '23 16:12

Here are additional benefits of employing an immutable type, such as a tuple, for representing shape:

  1. The empty tuple can be a singleton. That means for 0d arrays, the same reference to the empty tuple can be returned.
  2. It would allow caching the Python object representing the shape and returning a reference to that cached object, which is only possible for immutable objects.
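
Both points can be illustrated in plain Python. The shared identity of the empty tuple is a CPython implementation detail rather than a language guarantee, and `CachedShapeArray` is a hypothetical class used only to show the caching idea, not how MLX stores shapes:

```python
# In CPython the empty tuple is shared, while each empty list is a new object.
empty_tuples_shared = tuple() is tuple()
empty_lists_shared = list() is list()


class CachedShapeArray:
    """Toy array that builds its shape tuple once and reuses it."""

    def __init__(self, dims):
        self._shape = tuple(dims)  # immutable, so safe to hand out repeatedly

    @property
    def shape(self):
        return self._shape  # same object every time, no copy needed


arr = CachedShapeArray([2, 3])
first, second = arr.shape, arr.shape
```

Returning a cached list this way would be unsafe, because a caller mutating it would change what every later caller sees.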

dastrobu, Dec 29 '23 09:12

Thanks for the input @dastrobu !

To me the main reason to make this change is consistency with numpy. It is an annoying change which is why I'm somewhat resistant, as we will have to find and update everything that uses shapes mutably. But so far that's probably not that much code. On the flip side, the longer we wait the harder it gets.

The efficiency benefit is a nice plus but probably low impact for the foreseeable future.

@angeloskath may have some thoughts on this. I know for example it was an issue with starting a backend for Keras.

awni, Dec 29 '23 20:12

By the way, Python's memoryview uses tuples for strides and shapes as well:

>>> import numpy as np
>>> memoryview(np.ones(2)).shape
(2,)
>>> memoryview(np.ones(2)).strides
(8,)

see also #320

dastrobu, Dec 30 '23 16:12

Closed by #591.

dastrobu, Feb 02 '24 13:02