[Feature request] PyTorch-like operations on arrays
It would be nice to have some PyTorch-like operations that are very useful for tensor (array) manipulation, e.g. tensor.unique(), which returns an array of unique values with or without their counts, or tensor.unfold(), which is a much simpler and more straightforward way to apply a rolling window to an array than np.lib.stride_tricks.as_strided().
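For reference, the NumPy counterpart of tensor.unique() with counts is np.unique(..., return_counts=True); PyTorch's torch.unique accepts the same return_counts flag. A minimal sketch with illustrative values:

```python
import numpy as np

# Illustrative input
x = np.array([3, 1, 2, 3, 3, 1])

# Sorted unique values together with how many times each one occurs
values, counts = np.unique(x, return_counts=True)
print(values)  # [1 2 3]
print(counts)  # [2 1 3]
```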
Also, mx.where() is missing important functionality: when it receives only the condition as an argument, it should return the indices where the condition is true. The current implementation always requires three arguments. Both NumPy and PyTorch have this capability.
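For comparison, here is a small NumPy sketch of the two forms of where. With only the condition it behaves like np.nonzero and returns a tuple with one index array per dimension (in PyTorch, torch.where(condition) likewise matches torch.nonzero(condition, as_tuple=True)). The three-argument form is the elementwise select that mx.where already provides. The arrays are illustrative.

```python
import numpy as np

cond = np.array([[True, False, True],
                 [False, True, False]])

# Condition-only form: one index array per dimension (same as np.nonzero(cond))
rows, cols = np.where(cond)
print(rows)  # [0 0 1]
print(cols)  # [0 2 1]

# Three-argument form: elementwise select, the form mx.where already supports
x = np.arange(6).reshape(2, 3)
y = np.where(cond, x, -1)
print(y.tolist())  # [[0, -1, 2], [-1, 4, -1]]
```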
My point is only that having more compatible functions may simplify migrating existing torch code, at least for testing and verification.
Cool, we follow the NumPy API, so where and unique are definitely on the table for feature enhancements. Both are tricky to expose in MLX because they produce outputs whose shapes depend on the input data, which is a big reason we don't have them yet.
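To illustrate the data-dependent shape problem, here is a NumPy sketch with arbitrary example values: two inputs with the same shape and dtype give outputs of different lengths. Presumably this is what makes such ops awkward under MLX's lazy evaluation, since the output shape cannot be known until the input values have been computed.

```python
import numpy as np

# Same shape and dtype...
a = np.array([5, 0, 0, 0])
b = np.array([5, 6, 7, 0])

# ...but the result shapes depend on the values, not just on the input shape
print(np.nonzero(a)[0].shape)  # (1,)
print(np.nonzero(b)[0].shape)  # (3,)
print(np.unique(a).shape)      # (2,)
print(np.unique(b).shape)      # (4,)
```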
Unfold is something we could discuss if it's especially useful. We aren't opposed to including more than pure NumPy, but we usually want a really good reason to do so (e.g. ongoing use cases that would really benefit from having it).
Well, in my view unfold is more efficient than np.lib.stride_tricks.as_strided() for two reasons:
- In my understanding, it does not create a separate array and thus doesn't consume additional memory; instead, it just changes how the tensor's data is viewed.
- It is much more straightforward from a function-call perspective: just three clear arguments (see https://pytorch.org/docs/stable/generated/torch.Tensor.unfold.html#torch.Tensor.unfold). Of course, this may only be a matter of habit, since, frankly, I use torch much more than NumPy :-)
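A rough sketch of torch.Tensor.unfold(dimension, size, step) semantics in NumPy, built on as_strided. The helper name unfold and its argument validation are hypothetical, not an existing NumPy or MLX API. Like the torch version, it returns a view (no copy) and appends the window dimension last. Note that as_strided itself also returns a view, so the main gain is the simpler three-argument interface rather than memory use.

```python
import numpy as np
from numpy.lib.stride_tricks import as_strided


def unfold(x, dimension, size, step):
    """Hypothetical helper mimicking torch.Tensor.unfold(dimension, size, step).

    Returns a read-only view (no copy) containing every window of length `size`
    taken every `step` elements along `dimension`; the window axis is appended last.
    """
    dimension = dimension % x.ndim  # allow negative dimensions
    n = x.shape[dimension]
    if size < 1 or step < 1 or size > n:
        raise ValueError("need 1 <= size <= x.shape[dimension] and step >= 1")
    num_windows = (n - size) // step + 1
    # Replace the unfolded axis by the window-start axis, then append the window axis
    shape = x.shape[:dimension] + (num_windows,) + x.shape[dimension + 1:] + (size,)
    strides = (
        x.strides[:dimension]
        + (x.strides[dimension] * step,)  # jump `step` elements between window starts
        + x.strides[dimension + 1:]
        + (x.strides[dimension],)  # consecutive elements inside one window
    )
    return as_strided(x, shape=shape, strides=strides, writeable=False)


x = np.arange(7)
w = unfold(x, 0, 3, 2)
print(w)
# [[0 1 2]
#  [2 3 4]
#  [4 5 6]]
print(np.shares_memory(w, x))  # True
```

On NumPy 1.20+, np.lib.stride_tricks.sliding_window_view(x, size, axis=dimension) produces the step-1 windows, and slicing the window-start axis with [::step] gives the strided case, e.g. sliding_window_view(x, 3)[::2] for the 1-D example above.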
Technically, nothing prevents us from doing general array manipulation with NumPy or torch while using MLX for modeling and training.
I am facing the same issue while implementing the MoE block for the Mixtral model. My understanding is that mx.where doesn't support the condition-only form, which blocks vectorized computation over the selected experts (so we have to explicitly evaluate the indices and use np.where instead).
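A hypothetical NumPy sketch of the index-based MoE routing described above: for each expert, the condition-only np.where finds the (token, slot) pairs routed to it, and only those tokens are processed. All names, sizes, the random weights, and the softmax over the selected top-k logits are illustrative assumptions, not the Mixtral or mlx-examples implementation. The number of matches per expert depends on the data, which is exactly the data-dependent shape discussed earlier in the thread.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes and random weights, for illustration only
num_tokens, dim, num_experts, top_k = 6, 4, 4, 2
x = rng.standard_normal((num_tokens, dim))
gate_logits = rng.standard_normal((num_tokens, num_experts))
expert_w = rng.standard_normal((num_experts, dim, dim))  # one linear "expert" each

# Top-k expert indices per token (order within the k is arbitrary)
inds = np.argpartition(-gate_logits, top_k - 1, axis=-1)[:, :top_k]
# Assumed gating: softmax over the selected logits only
scores = np.take_along_axis(gate_logits, inds, axis=-1)
scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
scores /= scores.sum(axis=-1, keepdims=True)

y = np.zeros_like(x)
for e in range(num_experts):
    # Condition-only where: (token, slot) pairs routed to expert e.
    # How many there are depends on the data, so the values must be concrete.
    tok, slot = np.where(inds == e)
    if tok.size == 0:
        continue
    # Each token picks expert e at most once, so `tok` has no duplicates
    y[tok] += scores[tok, slot, None] * (x[tok] @ expert_w[e])
```

The usual way around data-dependent shapes is the dense variant: run every expert on every token and zero out the unselected contributions with the three-argument where, trading extra compute for static shapes.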