array-api

Add `argwhere`

jakirkham opened this issue • 4 comments

Currently `nonzero` is defined in the specification, but `argwhere` is not. PR (https://github.com/data-apis/array-api/pull/23) mentions `argwhere`, but it was not included.

There was some discussion, starting with this comment (https://github.com/data-apis/array-api/pull/23#issuecomment-859752792), about whether `argwhere` should be included. In that discussion `argwhere` was viewed as duplicative of `nonzero`, which is understandable.

However, for libraries whose arrays can have dimensions of unknown size (like Dask), `nonzero` can run into difficulties, because the separate per-dimension index arrays cannot be concatenated (at least not without an explicit override). `argwhere` avoids this issue: it already returns a single array, so there is no need to concatenate, and that array can easily be split apart to satisfy the `nonzero` case (for example: https://github.com/dask/dask/pull/2539). As a result, `argwhere` is more practically useful for handling these unknown-shape cases.

I would like to discuss adding `argwhere` to the spec to handle this need.
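A minimal NumPy sketch of the "split apart" step, assuming an eager NumPy array (this is not Dask's actual implementation, and `nonzero_from_argwhere` is a hypothetical helper name): `argwhere` returns one `(N, ndim)` array, and slicing out its columns gives the per-dimension tuple that `nonzero` returns.

```python
import numpy as np


def nonzero_from_argwhere(x):
    """Hypothetical helper: build the nonzero() result from argwhere()."""
    # argwhere returns a single (N, x.ndim) index array,
    # one row per nonzero element, in row-major (C) order.
    idx = np.argwhere(x)
    # Each column holds the indices along one dimension; slicing the
    # columns apart gives the per-dimension tuple that nonzero returns.
    return tuple(idx[:, dim] for dim in range(idx.shape[1]))


x = np.array([[0, 3, 0],
              [5, 0, 6]])
print(nonzero_from_argwhere(x))
print(np.nonzero(x))  # same index arrays as the line above
```

The column slices are views into `idx`, so the split itself copies no index data; no concatenation step is involved in this direction.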

jakirkham, Jun 06 '22 20:06

Let me copy my comment from gh-23 here, because I think it's still relevant/helpful:

start of old comment

Argh, why do we have both `nonzero` and `argwhere`! They're almost identical. I actually like `argwhere` better, but I believe it's used far less than `nonzero` because of its name.

PyTorch has an `unbind` method that efficiently turns a 2-D tensor into a tuple of 1-D tensors. Applied to the output of `argwhere`, it would be an O(1) operation that makes `nonzero` unnecessary. It would also resolve the pain point PyTorch has with its incompatible `nonzero` definition.

This may be worth considering. The name is a little annoying, though, especially because `argwhere` isn't a counterpart to `where`, but the same thing as the discouraged one-argument form of `where`.

```python
>>> import numpy as np
>>> x = np.array([0, 4, 7, 0])
>>> np.nonzero(x)
(array([1, 2]),)
>>> np.where(x)
(array([1, 2]),)
>>> np.argwhere(x)
array([[1],
       [2]])
>>> x[np.nonzero(x)]
array([4, 7])
>>> x[np.where(x)]
array([4, 7])
>>> x[np.argwhere(x)]  # this needs a squeeze() to roundtrip for 1-D input
array([[4],
       [7]])
```

end of old comment

Per the NumPy docs, `argwhere` returns `index_array`, an `(N, a.ndim)` ndarray.

Let me add a higher-dimensional example as well:

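```python
>>> import numpy as np
>>> x = np.arange(1, 25).reshape(4, 3, 2)
>>> x.shape
(4, 3, 2)
>>> x[np.nonzero(x)]
array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
       18, 19, 20, 21, 22, 23, 24])
>>> x[np.argwhere(x)].shape  # the (24, 3) index array only indexes axis 0
(24, 3, 3, 2)
```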

There are more complex ways to use the indices returned from `argwhere`, but I don't see a straightforward one (I may be missing it; still sleepy). So how useful is `argwhere` for doing an "unknown-shape `nonzero`"?
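One way to index with `argwhere` output for any number of dimensions, sketched in eager NumPy only (it does not address the unknown-shape Dask case): transpose the `(N, ndim)` result and turn it into a tuple, which is the same form `nonzero` returns. This also avoids the `squeeze()` needed in the 1-D case. The zeroing of multiples of 5 is an arbitrary choice to make the selection non-trivial.

```python
import numpy as np

# Same layout as the higher-dimensional example: shape (4, 3, 2), values 1..24.
x = np.arange(1, 25).reshape(4, 3, 2)
x[x % 5 == 0] = 0  # zero out 5, 10, 15, 20 so only 20 elements are nonzero

idx = np.argwhere(x)       # shape (N, x.ndim) == (20, 3)
picked = x[tuple(idx.T)]   # one 1-D index array per dimension -> 1-D result

# Same elements, in the same row-major order, as nonzero-based indexing.
print(picked.shape)                                # → (20,)
print(np.array_equal(picked, x[np.nonzero(x)]))    # → True

# The same expression works for the 1-D case without a squeeze().
x1 = np.array([0, 4, 7, 0])
print(x1[tuple(np.argwhere(x1).T)])                # → [4 7]
```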

rgommers, Jun 08 '22 08:06

> However, for libraries whose arrays can have dimensions of unknown size (like Dask), `nonzero` can run into difficulties, because the separate per-dimension index arrays cannot be concatenated (at least not without an explicit override). `argwhere` avoids this issue: it already returns a single array, so there is no need to concatenate, and that array can easily be split apart to satisfy the `nonzero` case (for example: dask/dask#2539). As a result, `argwhere` is more practically useful for handling these unknown-shape cases.

Can you give a concrete example of where `argwhere` lets you do something in Dask that `nonzero` does not?

shoyer, Jun 23 '22 01:06

It's trivial to implement `nonzero` from `argwhere`, since it just splits out the columns of the result. In fact, all of the `*nonzero`-style functions in Dask are implemented with `argwhere`. Going the opposite direction unfortunately requires a bit of a leap of faith: one needs to call `concatenate` and tell it to skip checking that the shapes match (in Dask this is `allow_unknown_chunksizes=True`). For `nonzero` this is fine, since we know the shapes should match (though that's not true in general, hence the option). Where it gets trickier is that telling `concatenate` to allow this behavior means using Dask Arrays in a non-spec-conforming way (the option does not exist in other array implementations), which could make it hard to write Array API code that accommodates Dask in these cases. This can also come up in other places where code depends on shape checking.
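For contrast, a NumPy-only sketch of the opposite direction (no Dask API is used, and `argwhere_from_nonzero` is a hypothetical helper name): the per-dimension arrays from `nonzero` have to be joined into one `(N, ndim)` array with `concatenate`. NumPy knows every length `N` eagerly, so the shape check passes; a lazy library whose `N` is unknown until compute time is where that same check gets in the way.

```python
import numpy as np


def argwhere_from_nonzero(x):
    """Hypothetical helper: rebuild argwhere() output from nonzero()."""
    per_dim = np.nonzero(x)  # tuple of x.ndim 1-D arrays, each of length N
    # Turn each length-N index array into an (N, 1) column.
    columns = [idx[:, np.newaxis] for idx in per_dim]
    # Join the columns side by side. concatenate requires all columns to
    # have the same length N along axis 0; NumPy can verify this eagerly.
    return np.concatenate(columns, axis=1)


x = np.array([[0, 3, 0],
              [5, 0, 6]])
print(np.array_equal(argwhere_from_nonzero(x), np.argwhere(x)))  # → True
```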

jakirkham, Jun 23 '22 17:06

Right, I guess my question is why getting all indices together into a single array matters.

shoyer, Jun 23 '22 17:06

As this proposal is without a champion, I'll go ahead and close.

kgryte, Jun 29 '23 08:06