
Proposal: add APIs for getting and setting elements via a list of indices (i.e., `take`, `put`, etc)

Open kgryte opened this issue 3 years ago • 35 comments

Proposal

Add APIs for getting and setting elements via a list of indices.

Motivation

Currently, the array API specification does not provide a direct means of extracting and setting a list of elements along an axis. Such operations are relatively common in NumPy usage either via "fancy indexing" or via explicit take and put APIs.

Two main arguments come to mind for supporting at least basic take and put APIs:

  1. Indexing does not currently support providing a list of indices to index into an array. The principal reason for not supporting fancy indexing stems from dynamic shapes and compatibility with accelerator libraries. However, use of fancy indexing is relatively common in NumPy and similar libraries where dynamically extracting rows/cols/values is possible and can be readily implemented.

  2. Currently, the output of a subset of APIs included in the standard cannot be readily consumed without manual workarounds if a specification-conforming library implemented only the APIs in the standard. For example,

    • argsort returns an array of indices. In NumPy, the output of this function can be consumed by put_along_axis and take_along_axis.
    • unique can return an array of indices if return_index is True.
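
To make the consumption problem concrete, here is how the two fit together in NumPy (take_along_axis is NumPy's API; at this point the standard included neither it nor take):

```python
import numpy as np

# argsort returns indices, not values; without take_along_axis (or
# equivalent fancy indexing), a conforming library has no standard
# way to consume them.
x = np.asarray([[3, 1, 2],
                [9, 7, 8]])
order = np.argsort(x, axis=1)                     # an array of indices
sorted_rows = np.take_along_axis(x, order, axis=1)  # consume the indices
print(sorted_rows)
```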

Background

The following table summarizes library implementations of such APIs:

| op | NumPy | CuPy | Dask | MXNet | Torch | TensorFlow |
| --- | --- | --- | --- | --- | --- | --- |
| extracting elements along axis | take | take | take | take | take/gather | gather/numpy.take |
| setting elements along axis | put | put | -- | -- | scatter | scatter_nd/tensor_scatter_nd_update |
| extracting elements over matching 1d slices | take_along_axis | take_along_axis | -- | -- | -- | gather_nd/numpy.take_along_axis |
| setting elements over matching 1d slices | put_along_axis | -- | -- | -- | -- | -- |

While most libraries implement some form of take, fewer implement other complementary APIs.

kgryte avatar May 06 '21 06:05 kgryte

Thanks @kgryte. A few initial thoughts:

  • There is also overlap between take/put and various scatter/gather functions in TensorFlow, PyTorch and MXNet. There's a whole host of those functions.
  • Is there really an issue with shape determinism? I'm probably missing something here, but isn't the output size along the given dimension equal to indices.size? And put is an in-place operation which doesn't change the shape.
  • If we do want to add these, we may consider putting them in the second version of the API. Just thinking that we should at some point stop making the API a permanently moving target.

rgommers avatar May 06 '21 14:05 rgommers

@rgommers Thanks for the comments.

  1. Correct. I've updated the table with Torch and TF scatter and gather methods.
  2. Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if extracting the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.
  3. Not opposed to delaying until V2 (2022).

kgryte avatar May 06 '21 16:05 kgryte

A natural question is: if take is supported, is there any reason equivalent indexing shouldn't also be supported? Granted, take only represents a specific subset of general (NumPy) integer array indexing, where indexing is done on a single axis.

asmeurer avatar May 06 '21 16:05 asmeurer

@asmeurer I think that take would be an optional API; whereas indexing semantics should be universal.

kgryte avatar May 06 '21 16:05 kgryte

Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if extracting the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.

If the size of indices is variable, it's the function that produces indices that is data-dependent; take itself is not. Compare with boolean indexing or nonzero: there the output size is in the range [0, x_input.size], while for take it's always exactly indices.size.
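
A NumPy sketch of this distinction:

```python
import numpy as np

# Boolean indexing has a data-dependent output size; take's output
# size is known statically from indices alone.
x = np.asarray([3, -1, 4, -1, 5])

mask_result = x[x > 0]                 # size depends on the values in x
indices = np.asarray([0, 2, 4])
take_result = np.take(x, indices)      # size == indices.size, known up front

print(mask_result.size, take_result.size)
```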

rgommers avatar May 07 '21 08:05 rgommers

@rgommers Correct; however, I could still imagine that data flows involving a take operation may be problematic for AOT computational graphs. While the output size is indices.size, an array library may not be able to statically allocate memory for the output of the take operation. That said, accelerator libraries do manage to support similar APIs (e.g., scatter/gather), so probably no need to further belabor this.

kgryte avatar May 11 '21 00:05 kgryte

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer indices or require broadcasting semantics among the various indices.

kgryte avatar May 11 '21 01:05 kgryte

Cross-linking to a discussion regarding issues concerning out-of-bounds access in take APIs for accelerator libraries.

kgryte avatar May 11 '21 04:05 kgryte

In the ML use case, it is common to want to sample with replacement or shuffle a dataset. This is commonly done by sampling an integer array and using it to subset the dataset:

import numpy.array_api as xp

X = xp.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=xp.float64)

sample_indices = xp.asarray([0, 0, 1, 3])

# Does not work
# X[sample_indices, :]

For libraries that need selection with integer arrays, a workaround is to implement take:

def take(X, indices, *, axis):
    # Simple implementation that only works for axis in {0, 1}
    if axis == 0:
        selected = [X[i] for i in indices]
    else:  # axis == 1
        selected = [X[:, i] for i in indices]
    return xp.stack(selected, axis=axis)

take(X, sample_indices, axis=0)

Note that sampling with replacement can not be done with a boolean mask, because some rows may be selected twice.
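
To illustrate that last point (np.take stands in here for the proposed API):

```python
import numpy as np

# A boolean mask selects each row at most once, so sampling with
# replacement (row 0 twice below) needs integer indices, not a mask.
X = np.asarray([[1, 2], [3, 4], [5, 6]])

resampled = np.take(X, np.asarray([0, 0, 2]), axis=0)  # row 0 appears twice
print(resampled)

mask = np.asarray([True, False, True])
subset = X[mask]  # each selected row can appear only once
```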

thomasjpfan avatar Dec 07 '21 17:12 thomasjpfan

Hi @kmaehashi @asi1024 @emcastillo FYI. In a recent array API call we discussed the proposed take/put APIs, and there were questions regarding how CuPy currently implements these functions, as there could be data/value dependency and people were wondering if we just have to pay the synchronization cost to ensure the behavior is correct. Could you help address this? Thanks! (And sorry I dropped the ball here...)

leofang avatar Feb 01 '22 04:02 leofang

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer indices or require broadcasting semantics among the various indices.

+1 I think "array only" integer indexing would be quite well defined, and would not be problematic for accelerators. The main challenge with NumPy's implementation of "advanced indexing" is handling mixed integer/slice/boolean cases.

shoyer avatar Mar 10 '22 18:03 shoyer

Here is a summary of today's discussion:

  • Implementing take is fine, there's no problem for accelerators and all libraries listed above already have this API. Given that they all have it, there's no problem adding take to the standard right now.
  • The __getitem__ part of indexing is equivalent to take. However, as @asmeurer pointed out, it would be odd to add support for integer array indexing in __getitem__ but not in __setitem__. Hence we need to look at the latter.
  • put and __setitem__ are also equivalent - and more problematic, for multiple reasons:
    • as the table in the issue description shows, put isn't widely supported across libraries, and not with the same name either.
    • put is explicitly an in-place function in NumPy et al., which is a problem for JAX/TensorFlow. Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like put.
    • @oleksandr-pavlyk suggested adding a new out-of-place version of put to the standard. However, that's a new function that libraries don't yet have (actually some do under names like index_put, but it's a mixed bag). And it's not clear that this would be preferred in the long term; an in-place put that is guaranteed to raise when it crosses paths with a view may be better.

Given all that, the proposal is to only add take now, and revisit integer array indexing and put in the future.

rgommers avatar Mar 24 '22 20:03 rgommers

Something that I think was missed in today's discussion is that take and put aren't exactly the same as integer array indexing. Integer array indices operate on the axes of the array. take and put (at least in NumPy) operate on the flattened array.

>>> a = np.arange(9).reshape((3, 3)) + 10
>>> a[np.array([0, 2]), np.array([1, 2])]
array([11, 18])
>>> np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3))
array([1, 8])
>>> np.take(a, np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3)))
array([11, 18])

np.take also has an axis parameter but that's only equivalent to a single integer array index.

I'm not sure if there's an easy way within the array API to go from one to the other.
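
For the 2-D case, one way to go from axis-style indices to a flat take is to compute the offsets by hand (a NumPy sketch of what ravel_multi_index does, not a proposed API):

```python
import numpy as np

# Emulate a[rows, cols] on a 2-D array using only a flat take, by
# computing flattened offsets manually (row * ncols + col).
a = np.arange(9).reshape((3, 3)) + 10
rows = np.asarray([0, 2])
cols = np.asarray([1, 2])

flat = rows * a.shape[1] + cols            # [1, 8]
result = np.take(np.reshape(a, (-1,)), flat)
print(result)                              # same elements as a[rows, cols]
```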

And I hope the "integer array as the sole index" idea above was really meant to be "integer arrays as the sole indices". Just having a single integer array index means you can only index the first dimension of the array. This should also include integer scalars, as those are equivalent to 0-D arrays, unless we want to omit the "all integer array indices are broadcast together" rule.
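
For reference, this is the NumPy behaviour in question: integer array indices broadcast together, and a scalar index behaves like a 0-D array:

```python
import numpy as np

# Integer array indices broadcast together; a scalar index acts as a
# 0-D array, so a[1, cols] selects within row 1.
a = np.arange(9).reshape((3, 3))

rows = np.asarray([[0], [2]])   # shape (2, 1)
cols = np.asarray([0, 2])       # shape (2,), broadcasts to (2, 2)
print(a[rows, cols])            # [[0, 2], [6, 8]]
print(a[1, cols])               # [3, 5]
```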

I agree that NumPy's rules for mixing arrays with slices should not be included, especially the crazy rule about how it handles slices between integer array indices, which is a design mistake in NumPy (slices around integer array indices aren't so bad, and can be useful, but they also add complexity to the indexing rules, so I can see wanting to omit them).

asmeurer avatar Mar 24 '22 21:03 asmeurer

It's definitely possible (but not necessarily easy) to rewrite every call to np.take in terms of __getitem__ with integer arrays. For a library like Xarray, support for all integer indexing (especially with broadcasting) would be sufficient. So from my perspective, support for all integer indexing in __getitem__ and possibly also __setitem__ would be the most useful functionality.

I would not be opposed to adding take if there is interest. It certainly is easier to construct calls to take, and knowing ahead of time that indexing is only going along a certain axis can sometimes allow for significant simplifications to indexing code. There are two alternatives we could consider for filling this same niche (easy integer based indexing along one dimension):

  1. Support for mixed array/slice indexing, like NumPy. But like Aaron says, this is too confusing for the API standard.
  2. We could include oindex, but this proposal never got entirely off the ground (beyond implementations in Xarray/Dask/Zarr).

If we do choose to include take in the standard, the axis argument should be required. Slicing along flattened arrays is not very useful.

shoyer avatar Mar 24 '22 23:03 shoyer

It's definitely possible (but not necessarily easy) to rewrite every call to np.take in terms of __getitem__ with integer arrays. For a library like Xarray, support for all integer indexing (especially with broadcasting) would be sufficient. So from my perspective, support for all integer indexing in __getitem__ and possibly also __setitem__ would be the most useful functionality.

The suggestion here is to support take but defer support for indexing. So users of the array API would need to rewrite usages of __getitem__ to take, not the other way around.

Slicing along flattened arrays is not very useful.

I've never really used take myself, so I don't have the best context here, but isn't the flattened behavior there to match put, which doesn't have axis?

asmeurer avatar Mar 24 '22 23:03 asmeurer

What is the concern with supporting integer array indexing in __setitem__? Just the fact that it may not be implemented in otherwise compliant array libraries?

shoyer avatar Mar 25 '22 00:03 shoyer

What is the concern with supporting integer array indexing in __setitem__? Just the fact that it may not be implemented in otherwise compliant array libraries?

That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on scatter/scatter_nd.
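
In NumPy the outcome with duplicate indices happens to be "last write wins", with np.add.at as the explicit accumulating alternative; accelerator scatter implementations make no such ordering guarantee, which is why the behavior would be left unspecified:

```python
import numpy as np

# With duplicate indices, fancy-index assignment in NumPy happens to
# apply the last value; scatter on accelerators guarantees no order.
a = np.zeros(3)
a[np.asarray([0, 0])] = np.asarray([1.0, 2.0])
print(a)    # NumPy gives [2., 0., 0.], but this is not guaranteed in general

# Explicit accumulation is a separate, well-defined operation:
b = np.zeros(3)
np.add.at(b, np.asarray([0, 0]), np.asarray([1.0, 2.0]))
print(b)    # [3., 0., 0.]
```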

rgommers avatar Mar 30 '22 11:03 rgommers

That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on scatter/scatter_nd.

I think we could probably safely leave this as undefined behavior?

shoyer avatar Mar 30 '22 15:03 shoyer

I think we could probably safely leave this as undefined behavior?

Yes, fair enough.

Let me add another concern though, probably the main one (copied from higher up, with a minor edit: put --> __setitem__): Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like __setitem__.

I think my preferred order of doing things here would be:

  1. Add take with 1-D integer array indices now (see gh-416)
  2. Tighten up mutability specification
  3. Add __getitem__ and __setitem__ (with n-D integer inputs, assuming that behavior aligns across libraries).

rgommers avatar Apr 07 '22 20:04 rgommers

Very interested in having at least a basic version of take to be incorporated into the standard.

Context: experimental version of more verbose indexing, see https://github.com/arogozhnikov/einops/issues/194 for details

arogozhnikov avatar Jun 27 '22 10:06 arogozhnikov

Thanks for the ping on this issue @arogozhnikov - and nice to see the experimental work on indexing in einops. I'd like to see gh-416 finished and merged in the coming days to indeed add take support with 1-D indices.

rgommers avatar Jul 07 '22 13:07 rgommers

Support for take has been merged, see gh-416.

rgommers avatar Nov 16 '22 20:11 rgommers

nit. we also have put_ in PyTorch (but not put...)

lezcano avatar Feb 15 '23 00:02 lezcano

I think if we were to introduce something like xp.put(x, indices, value) to the spec we can seemingly all agree on

  1. Only specifying a single array as the indices argument like we do with xp.take(), leaving other kinds of indices out-of-scope.

    e.g. xp.put(x, indices, xp.asarray([42, 7])) where x=xp.arange(5) and indices=xp.asarray([1, 4]) would be supported, but the following equivalent arguments would be out-of-scope:

    indices=(xp.asarray(1), xp.asarray(4))
    indices=(1, 4)
    indices=(1, xp.asarray(4))
    indices=[1, 4]
    

    Array only indexing makes adoption easier and doesn't cause problems for accelerators.

  2. Keeping in line with xp.take(), we should specify support for indices as 1-dimensional only.

    • Implicitly I'd be mandating that elements in indices index into the flattened input array, rather than following any fancy broadcasting behaviours, etc.

      e.g. on the contrary, t.index_put_() uses the shape of indices to specify multiple elements of the input array.

      >>> t = torch.as_tensor([[True, True]])
      >>> t.index_put_((torch.as_tensor([0]),), torch.as_tensor(False))
      tensor([[False, False]])
      
  3. Only specifying support for unique indices, e.g. indices=xp.asarray([0, 0]) would be out-of-scope. Consistent duplicate indices behaviour seems too niche and finicky to specify.

    • Interestingly PyTorch has the accumulate keyword for its t.put_()/t.index_put_() methods, where accumulate=False (default) leaves such behaviour unspecified, and accumulate=True puts the sum of the respective values.

The question areas then would be

  1. Should we support broadcasting value to the shape of indices?

    • np.put() broadcasts(?) the value to the indices, e.g.

      >>> a = np.asarray([True, True])
      >>> np.put(a, np.asarray([0, 1]), np.asarray(False))
      >>> a
      array([False, False])
      
    • On the contrary, torch.put_() requires the indices (index) to share the same shape as the value (source).

    As broadcasting is convenient and very common throughout the spec, IMO I'd specify that value can be broadcast to indices.

    Regardless I think we'd disallow broadcast-incompatible shapes, and value.size > indices.size scenarios.

  2. Should xp.put() return an array? What are the expectations for in-place and out-of-place behaviour?

    • The spec currently always(?) returns arrays from its functions, which seems a nice convention to maintain.
    • Notably NumPy for its top-level np.put() currently only acts on the array in-place and does not return the modified array, whereas PyTorch has its put-like functions/methods return the modified array (some functions/methods also acting in-place).

    If we mandate xp.put() is to return the modified input array, we could leave in-place behaviour out-of-scope, or slap a copy keyword like we do for xp.asarray() and xp.reshape().

    Worth noting that the name put() also suggests in-place behaviour at this point.
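
A minimal sketch of such an out-of-place variant, written against NumPy (the name put_out_of_place and the exact semantics — value broadcast to indices, modified copy returned — are illustrative, not settled):

```python
import numpy as np

def put_out_of_place(x, indices, values):
    # Hypothetical sketch: returns a modified copy rather than mutating
    # x in place; values is broadcast to indices.shape (option 1 above);
    # indices index into the flattened array (point 2 above).
    out = np.array(x, copy=True)
    out.flat[indices] = np.broadcast_to(values, indices.shape)
    return out

a = np.asarray([True, True, True])
b = put_out_of_place(a, np.asarray([0, 2]), False)
print(b)   # first and last elements set to False
print(a)   # original array is unchanged
```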

honno avatar Apr 18 '23 09:04 honno

Some thoughts on @honno's points

  1. I think that's because all those objects are array_likes in NumPy. The same happens for any function that accepts an array in NumPy.
  2. SGTM but perhaps extending it to "indices of dimension at most 1". I believe PyTorch accepts wlog any contiguous array of any size, but I've never seen it used with anything but arrays of dim 0 or 1
  3. Checking for repeated values is indeed too costly. I think it should be left unspecified.
  4. Either SGTM
  5. In PyTorch we return an array for in-place ops, following the C++ convention. This is sometimes used to be able to chain in-place ops. I think it'd be good to mandate this all throughout the API.

lezcano avatar Apr 18 '23 10:04 lezcano

  1. You discussed broadcasting values to indices. What about broadcasting indices to values? (A case of specific interest to me.)
>>> a = np.asarray([[True, True]])
>>> np.put(a, np.asarray([0]), np.asarray([[False, True]]))
>>> a
array([[False,  True]])

arogozhnikov avatar Apr 18 '23 19:04 arogozhnikov

Broadcasting indices should give, by definition, repeated indices, which should invoke UB (which value is written to the index 0 in your example?).

lezcano avatar Apr 18 '23 20:04 lezcano

No, see my example above: the values consist of one row, and the index specifies that the first row of values should be assigned to the first row of the result. I am not sure this is strictly a case of broadcasting, but it's a common thing to do.

Compare with:

matrix_n_by_n[[1, 2, 6]] = matrix_3_by_n

arogozhnikov avatar Apr 18 '23 21:04 arogozhnikov

We should clarify in the spec that behavior on out-of-bounds indices is unspecified. The take spec currently doesn't say anything about this (I'm assuming this is the behavior we want, since we already say this for basic integer indexing).
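
For comparison, NumPy's take makes the out-of-bounds policy an explicit per-call choice via its mode parameter, which is exactly the kind of behavior the spec would leave unspecified:

```python
import numpy as np

# NumPy lets the caller pick out-of-bounds handling for take;
# other libraries differ, so the spec would leave this unspecified.
x = np.arange(5)
print(np.take(x, [7], mode='wrap'))   # 7 % 5 == 2
print(np.take(x, [7], mode='clip'))   # clipped to the last index, 4
# mode='raise' (the default) raises IndexError instead
```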

asmeurer avatar Apr 18 '23 23:04 asmeurer

I had a look at implementations in libraries, some updates on what's in the issue description:

  • PyTorch, in addition to scatter, has Tensor.put_, so a method, and the trailing underscore indicating in-place behavior. A tensor is also returned.
  • Dask still doesn't have put or any of the other similar functions like putmask or put_along_axis. There doesn't seem to be a blocker though, it's only that no one has done the work yet (e.g., see https://github.com/dask/dask/issues/3664 with a put_along_axis feature request).
  • JAX actually has it in its namespace (jax.numpy.put), but the implementation is:
def put(*args, **kwargs):
  raise NotImplementedError(
    "jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
    "For functional approaches to updating array values, see jax.numpy.ndarray.at: "
    "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")

for reasons similar to why it avoids other in-place APIs (xref design_topics/copies_views_and_mutation).

So I think we should consider the addition of put feasible in principle but blocked right now. The JAX issue is the most difficult to resolve (it can be done, but there is still a lot of work needed to deal with read-only views or similar), but the lack of API uniformity makes this a hard sell in general.

rgommers avatar Apr 19 '23 05:04 rgommers