array-api-compat
torch `take_along_axis` does not support negative indices
The standard states that `take_along_axis` must support negative indices as usual, but with the torch namespace:
```python
from array_api_compat import torch as xp
xp.take_along_axis(xp.asarray([1]), xp.asarray([-1]))
# RuntimeError: index -1 is out of bounds for dimension 0 with size 1
```
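For context, "negative indices as usual" means that an index `i < 0` refers to position `i + n`, where `n` is the length of `x` along `axis`. The sketch below shows that normalization done by hand before calling `torch.take_along_dim`. The helper name `take_along_axis_normalized` is hypothetical, and this is only an illustration of the expected semantics, not necessarily how gh-361 implements the fix. Out-of-range indices (outside `[-n, n)`) are deliberately left for torch to reject.

```python
import torch


def take_along_axis_normalized(x: torch.Tensor, indices: torch.Tensor, axis: int = -1) -> torch.Tensor:
    """Hypothetical helper: take_along_axis that accepts negative indices.

    Negative entries are shifted by the length of ``x`` along ``axis`` so that
    -1 refers to the last element, -2 to the second-to-last, and so on.
    Indices outside [-n, n) are not checked here; torch raises for them.
    """
    n = x.shape[axis]
    # Map i -> i + n for negative i; non-negative indices are unchanged.
    normalized = torch.where(indices < 0, indices + n, indices)
    return torch.take_along_dim(x, normalized, dim=axis)


x = torch.asarray([10, 20, 30])
print(take_along_axis_normalized(x, torch.asarray([-1, 0, -2])))
# prints tensor([30, 10, 20])
```

With this normalization, the reproducer above (`[1]` indexed with `[-1]`) returns `[1]` instead of raising.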
See https://github.com/data-apis/array-api-compat/pull/361 for a fix and https://github.com/data-apis/array-api-tests/pull/397 for a test.
It would be helpful if you could test-drive gh-361, @mdhaber.