mlx
array.shape should return tuple not list for maximal compatibility with numpy
I like the statement:
MLX has a Python API that closely follows NumPy.
One thing I noticed is that mx.array.shape returns a list:
mx.ones([2, 3]).shape
[2, 3]
while numpy returns a tuple:
np.ones([2, 3]).shape
(2, 3)
Given that tuples are immutable containers in Python, returning a tuple makes more sense from my point of view.
I am not sure whether this was a design decision. But given that MLX is still in an early phase, it might be worth considering changing the shape to a tuple for maximal consistency with NumPy.
To be honest, this wasn't a very intentional design decision. Pybind11 by default translates std::vector to a Python list, which is why it is the way it is. I believe it's easy to change, but it would likely be a breaking API change for downstream code.
I see arguments in favor of each, and I think it would be good to have some discussion around the pros and cons. So please chime in with your thoughts. Here are some of mine:
- List is nice because it can be used almost everywhere a tuple can be (the exception being dictionary keys, since lists are not hashable) and it is easy to modify. A common idiom is to get the shape from an array, modify it, and use it in another operation, like a reshape.
- Tuple is nice for perfect compatibility with NumPy. Beyond that, what advantages does it have?
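A small sketch of the interchangeability point, using plain Python values rather than MLX arrays: lists and tuples behave the same for indexing, `len()`, and unpacking, but only a tuple is hashable and can serve as a dictionary key (for example, to cache something per shape). The `kernel_cache` dictionary below is hypothetical and only illustrates the hashing rule.

```python
shape_list = [2, 3]
shape_tuple = (2, 3)

# Indexing, len() and unpacking work the same for both types.
rows, cols = shape_list
same_behaviour = (rows, cols) == shape_tuple and len(shape_list) == len(shape_tuple)

# Only the tuple can be used as a dict key; lists are unhashable.
kernel_cache = {}  # hypothetical per-shape cache
kernel_cache[shape_tuple] = "entry for a 2x3 shape"

try:
    kernel_cache[shape_list] = "never stored"
    list_is_hashable = True
except TypeError:
    list_is_hashable = False

print(same_behaviour, list_is_hashable)  # True False
```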
Thanks for your comment.
I think it does not matter much: converting between lists and tuples is easy, and performance should not be a concern when dealing with shapes.
Compatibility with NumPy is the biggest advantage from my point of view, as it makes it easy to migrate existing code from NumPy arrays to MLX arrays.
I can live with lists; I just thought it might have been overlooked, since NumPy compatibility seemed to be a design goal, which I really like because it makes learning MLX easier.
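One concrete pitfall when porting NumPy code: a list never compares equal to a tuple, so a shape check written for NumPy silently evaluates to False when `.shape` is a list. The plain values below stand in for `np.ones([2, 3]).shape` and `mx.ones([2, 3]).shape` as shown in the example above; no array library is needed to see the effect.

```python
numpy_style_shape = (2, 3)  # stand-in for np.ones([2, 3]).shape
list_style_shape = [2, 3]   # stand-in for mx.ones([2, 3]).shape returning a list

# A typical NumPy-style shape check.
print(numpy_style_shape == (2, 3))  # True
print(list_style_shape == (2, 3))   # False: a list never equals a tuple

# Converting explicitly makes the check work for either return type.
print(tuple(list_style_shape) == (2, 3))  # True
```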
Implementation-wise, this could be as easy as returning py::tuple(py::cast(a.shape())):
// TODO, this makes a deep copy of the shape
// implement alternatives to use reference
// https://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html
.def_property_readonly(
    "shape",
    // Convert the std::vector to a Python object, then wrap it in a tuple.
    [](const array& a) { return py::tuple(py::cast(a.shape())); })
Regarding the implementation details, I'm unsure whether the use of py::cast addresses the outstanding TODO in the code. I have limited familiarity with pybind11, but the term 'cast' suggests a non-copying operation to me. See also https://pybind11.readthedocs.io/en/stable/advanced/pycpp/object.html#casting-back-and-forth.
Once there is agreement on which Python type to use, I'd suggest exposing strides to Python in addition to shape, as NumPy does.
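For reference, NumPy reports strides in bytes, and for a C-contiguous (row-major) array they follow directly from the shape and the element size. The sketch below computes them; `c_contiguous_strides` is a hypothetical helper, not an MLX or NumPy function. The result is cross-checked against Python's built-in `memoryview`, which also reports shape and strides as tuples.

```python
def c_contiguous_strides(shape, itemsize):
    """Byte strides of a row-major array with the given shape (hypothetical helper)."""
    strides = []
    step = itemsize
    # Walk the axes from last to first: the last axis is contiguous, and each
    # earlier axis jumps over one full block of all the later axes.
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


# A 2 x 3 array of float64 values (8 bytes per element).
print(c_contiguous_strides((2, 3), 8))  # (24, 8)

# Cross-check with the standard library: view 48 raw bytes as a 2 x 3 array of doubles.
view = memoryview(bytearray(48)).cast("d", (2, 3))
print(view.shape, view.strides)  # (2, 3) (24, 8)
```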
Here are additional benefits of employing an immutable type, such as a tuple, for representing shape:
- The empty tuple can be a singleton, so for 0-d arrays the same reference to the empty tuple can be returned.
- It would allow caching the Python object that represents the shape and returning a reference to that cached object. This is only safe for immutable objects, since otherwise a caller could modify the cached shape.
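A minimal pure-Python sketch of the caching argument, assuming an array object that stores its shape internally. `ListShapeArray` and `TupleShapeArray` are hypothetical classes, not MLX code. With a list, the property has to return a fresh copy on every access, or callers could mutate the array's internal shape; with a tuple, the same object can be handed out every time. The empty-tuple identity check relies on a CPython implementation detail.

```python
class ListShapeArray:
    """Hypothetical array that exposes its shape as a list."""

    def __init__(self, shape):
        self._shape = list(shape)

    @property
    def shape(self):
        # Must copy: returning self._shape would let callers mutate internal state.
        return list(self._shape)


class TupleShapeArray:
    """Hypothetical array that exposes its shape as a cached tuple."""

    def __init__(self, shape):
        self._shape = tuple(shape)  # built once, at construction

    @property
    def shape(self):
        # Safe to return the cached object: a tuple cannot be modified.
        return self._shape


a = ListShapeArray([2, 3])
b = TupleShapeArray([2, 3])
print(a.shape is a.shape)  # False: a new list on every access
print(b.shape is b.shape)  # True: the same tuple every time

# A 0-d array has the empty tuple as its shape; CPython keeps it as a singleton.
scalar = TupleShapeArray([])
print(scalar.shape is tuple())  # True on CPython
```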
Thanks for the input @dastrobu !
To me, the main reason to make this change is consistency with NumPy. It is an annoying change, which is why I'm somewhat resistant: we will have to find and update everything that uses shapes mutably. But so far that's probably not much code. On the flip side, the longer we wait, the harder it gets.
The efficiency benefit is a nice plus but probably low impact for the foreseeable future.
@angeloskath may have some thoughts on this. For example, I know it was an issue when starting a backend for Keras.
By the way, Python's memoryview also uses tuples for shape and strides:
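To illustrate the kind of downstream code that would need updating: the "get the shape, modify it, reshape" idiom edits the shape in place, which raises `TypeError` once `.shape` is a tuple. Copying into a list first, or building a new tuple, works with either return type. The `shape` value below is a plain tuple standing in for an array's `.shape`; the actual reshape call is only indicated in a comment.

```python
shape = (4, 6)  # stand-in for x.shape once it returns a tuple

# Old idiom: modify the returned shape in place. This fails for tuples.
try:
    shape[0] //= 2
    inplace_failed = False
except TypeError as err:
    inplace_failed = True
    print("in-place edit failed:", err)

# Portable fix 1: copy into a list, then edit the copy.
new_shape = list(shape)
new_shape[0] //= 2
new_shape[1] *= 2
print(new_shape)  # [2, 12]

# Portable fix 2: build a new tuple directly.
new_shape_t = (shape[0] // 2, shape[1] * 2)
print(new_shape_t)  # (2, 12)

# Either value can then be passed to a reshape, e.g. np.reshape(x, new_shape).
```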
memoryview(np.ones(2)).shape
(2,)
memoryview(np.ones(2)).strides
(8,)
See also #320.
Closed by #591.