array-api-compat

`torch` support indexing with negative step

Open • mdhaber opened this issue • 9 comments

The array API standard appears to support negative steps:

> The basic slice syntax is `i:j:k` where `i` is the starting index, `j` is the stopping index, and `k` is the step (`k != 0`).

But tensors from `array_api_compat.torch` do not:

```python
from array_api_compat import torch
x = torch.arange(10)
x[::-1]  # ValueError: step must be greater than zero
```

Adding support for negative steps would be appreciated! (In the meantime, I can use `flip`.) Thanks for considering it.
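A minimal sketch of the `flip` workaround, using plain PyTorch. It assumes only `torch.arange` and `torch.flip(input, dims)`; the point it illustrates is that `torch.flip` returns a new tensor rather than a view, so writes to the result do not reach the original.

```python
import torch

x = torch.arange(10)

# Slicing with a negative step is rejected by PyTorch.
try:
    x[::-1]
except ValueError as e:
    print("ValueError:", e)

# torch.flip produces the reversed values, but as a copy, not a view.
rev = torch.flip(x, dims=(0,))
print(rev.tolist())  # [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

# Because rev is a copy, modifying it leaves x unchanged.
rev[0] = 100
print(x[-1].item())  # 9
```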

mdhaber, Jun 08 '24 22:06

It would have to be via a wrapper function, since we don't wrap the tensor objects. Does torch have a function that reverses a tensor? Maybe it should be done by manipulating the strides?

asmeurer, Jun 08 '24 23:06

It doesn't look like it. One of the most recent requests for this feature is pytorch/pytorch#59786, and it links to one of the very old requests, pytorch/pytorch#229. Looks like it's just not implemented. `flip` is the substitute, but it makes a copy.

mdhaber, Jun 08 '24 23:06

Hmm. If you try to do this with strides, you get an error:

```pycon
>>> a = torch.arange(10)
>>> torch.as_strided(a, a.shape, (-8,))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: as_strided: Negative strides are not supported at the moment, got strides: [-8]
```

So I think torch just fundamentally doesn't support reversed views right now. The best we can do is a helper to translate a slice into a transformation with flip, which would be a copy as you noted.
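One way such a helper could translate a negative-step slice: index `i` of an axis of length `n` becomes index `n - 1 - i` after flipping, so a slice walking downward from `start` becomes a positive-step slice walking upward from `n - 1 - start` on the flipped axis. The sketch below assumes a single axis; `flipped_slice` is a hypothetical name, and it is checked here against Python lists (where reversing a list stands in for `flip`).

```python
def flipped_slice(s: slice, n: int) -> slice:
    """Hypothetical helper: given a negative-step slice `s` for an axis of
    length `n`, return an equivalent positive-step slice to apply to the
    flipped axis."""
    start, stop, step = s.indices(n)  # normalized; raises for step == 0
    if step > 0:
        raise ValueError("expected a negative step")
    # Original index i maps to n - 1 - i in the flipped axis, so walking
    # down from `start` (exclusive of `stop`) becomes walking up from
    # n - 1 - start (exclusive of n - 1 - stop) with step -step.
    return slice(n - 1 - start, n - 1 - stop, -step)


data = list(range(10))
for s in [slice(None, None, -1), slice(8, 2, -2), slice(None, 3, -3)]:
    # Reversing the list plays the role of flip here.
    assert data[::-1][flipped_slice(s, len(data))] == data[s]

print(flipped_slice(slice(None, None, -1), 10))  # slice(0, 10, 1)
```

With PyTorch, the same mapping would be applied as `torch.flip(x, dims=(0,))[flipped_slice(s, x.shape[0])]`, which, as noted, is a copy rather than a view.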

What generality of slices do you need support for? Steps less than -1? Start and stop? Slices in multidimensional indices?

asmeurer, Jun 10 '24 22:06

For the time being I've used `flip`, and that probably won't change before my patch is merged into SciPy. So from that perspective I don't need anything; I just thought I should report the issue. But in the context where I encountered this issue, I only needed `[::-1]`.

mdhaber, Jun 10 '24 22:06

This reminds me of how we've discussed getting around the fact that JAX can't mutate arrays: we've discussed a function that mutates elements at specified indices if possible and copies otherwise (maybe with the JAX `.at` syntax). Is array_api_compat the place for that, or does each project need to implement its own?
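A minimal sketch of the "mutate if possible, otherwise copy" idea, assuming that mutable array types (such as NumPy or PyTorch) accept item assignment and that JAX arrays raise `TypeError` on it and offer the functional `x.at[idx].set(value)` update instead. `set_at` is a hypothetical name, not an array_api_compat function.

```python
def set_at(x, idx, value):
    """Hypothetical helper: set x[idx] = value in place when the array type
    allows it, otherwise fall back to a JAX-style functional update.

    Callers must use the return value: for immutable arrays it is a new
    object, not `x` itself.
    """
    try:
        # Mutable array types update in place and return the same object.
        x[idx] = value
        return x
    except TypeError:
        # JAX arrays refuse item assignment; `.at[idx].set(...)` returns
        # an updated copy.
        return x.at[idx].set(value)


a = [1, 2, 3]  # a plain list stands in for a mutable array here
b = set_at(a, 0, 9)
print(b is a, a)  # True [9, 2, 3]
```

Catching `TypeError` is a shortcut: it could also swallow unrelated type errors, so a real helper would more likely dispatch on the array type explicitly.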

It seems to me we might want a similar thing here, because it might not be OK to copy if the user is expecting a view.

Another possibility is simply to re-raise the error with a message explaining why array_api_compat can't implement negative steps and recommending `flip` if a copy is acceptable.
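A sketch of the re-raising option, assuming the check happens in a wrapper before the key reaches the tensor, so the library's own error message never needs to be parsed. `reject_negative_step` and `getitem` are hypothetical names; the guard itself is plain Python and works with any indexable object.

```python
def reject_negative_step(key) -> None:
    """Hypothetical guard: raise an explanatory error if an index contains
    a slice with a negative step, which PyTorch tensors cannot represent
    as a view."""
    items = key if isinstance(key, tuple) else (key,)
    for k in items:
        if isinstance(k, slice) and k.step is not None and k.step < 0:
            raise ValueError(
                f"{k} has a negative step; PyTorch does not support "
                "negative-stride views, so this cannot be done without "
                "copying. Use flip (which returns a copy) if a copy is "
                "acceptable."
            )


def getitem(x, key):
    """Index x with key after checking for negative steps."""
    reject_negative_step(key)
    return x[key]


print(getitem([1, 2, 3], slice(0, 2)))  # [1, 2]
```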

mdhaber, Jun 10 '24 22:06

array-api-compat generally isn't the place to implement new APIs that aren't in the standard (see https://data-apis.org/array-api-compat/#scope).

However, something that could be in scope for array-api-compat is helper functionality to work around how different libraries handle copies vs. views. I don't know what that would look like exactly, but if you have any proposals for things that could help, I'm open to hearing them. We should open a new issue to discuss this.

asmeurer, Jun 11 '24 19:06

At any rate, this issue makes me realize that a function that converts a slice into strides and an offset could be a generally useful thing. I might implement it in ndindex at some point https://github.com/Quansight-Labs/ndindex/issues/180.
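A sketch of what such a conversion could look like for one axis, assuming strides counted in elements (as PyTorch does) and a base array whose element `i` lives at position `i * stride`. `slice_to_offset_stride` is a hypothetical name and not ndindex's API.

```python
def slice_to_offset_stride(s: slice, n: int, stride: int = 1):
    """Hypothetical helper: describe x[s], for a 1-D array of length n with
    the given element stride, as (offset, new_stride, new_length)."""
    start, stop, step = s.indices(n)
    length = len(range(start, stop, step))
    # For an empty result the offset is irrelevant; use 0.
    offset = start * stride if length else 0
    return offset, step * stride, length


print(slice_to_offset_stride(slice(None, None, -1), 10))  # (9, -1, 10)
print(slice_to_offset_stride(slice(2, 8, 3), 10))  # (2, 3, 2)
```

Any slice with a negative step yields a negative `new_stride`, which is exactly what `torch.as_strided` rejects, so this mapping shows why a reversed view cannot be expressed in PyTorch today. A base tensor with a nonzero storage offset would add that offset to the result.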

asmeurer, Jun 11 '24 22:06

We should also document this as a known difference with pytorch.

asmeurer, Nov 15 '24 00:11

https://data-apis.org/array-api-compat/supported-array-libraries.html#pytorch now mentions that:

> Slices do not support negative steps.

And from previous interactions with PyTorch devs, my impression is that the assumption that slice steps are positive is baked too deeply into the torch core to realistically hope it will be fixed in torch itself.

ev-br, Nov 15 '24 11:11