array-api
Add `argwhere`
Currently `nonzero` is defined in the specification, but `argwhere` is not. PR (https://github.com/data-apis/array-api/pull/23) mentions `argwhere`, but it is not currently included.
There was some discussion, starting with this comment (https://github.com/data-apis/array-api/pull/23#issuecomment-859752792), about whether `argwhere` should be included. `argwhere` was viewed as duplicative, which is understandable.
However, for libraries that have arrays with unknown shape elements (like Dask), `nonzero` can run into some difficulties, as the different arrays may not be concatenated (at least not without an explicit override). `argwhere` avoids this issue: it already returns a single array, so there is no need to `concatenate`, and the array can be split apart easily to satisfy the `nonzero` case (for example: https://github.com/dask/dask/pull/2539). As a result, `argwhere` becomes practically more useful for handling these unknown-shape cases.
I would like to discuss adding `argwhere` to the spec to handle this need.
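To make the "split apart" step concrete, here is a minimal NumPy sketch of deriving a `nonzero`-style tuple from `argwhere`. The helper name `nonzero_from_argwhere` is hypothetical and used only for illustration; this is not the Dask implementation, just the same idea on an array with a known shape.

```python
import numpy as np

def nonzero_from_argwhere(x):
    """Hypothetical helper: build a nonzero-style tuple from argwhere."""
    idx = np.argwhere(x)  # single array of shape (N, x.ndim)
    # Split by column: one 1-D index array per dimension of x.
    return tuple(idx[:, d] for d in range(x.ndim))

x = np.array([[0, 4], [7, 0]])
result = nonzero_from_argwhere(x)
expected = np.nonzero(x)
```

Because the split only slices columns out of one array, no concatenation (and so no shape check across separate arrays) is involved.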
Let me copy my comment from gh-23 here, because I think it's still relevant/helpful:
start of old comment
Argh, why do we have both `nonzero` and `argwhere`! They're almost identical. I actually like `argwhere` better, but because of the name it's used far less than `nonzero`, I believe.
PyTorch has an `unbind` method that allows going efficiently from a 2-D tensor to a tuple of 1-D tensors. Applied to the output of `argwhere`, it would be an O(1) method that makes `nonzero` unnecessary. And it would resolve the pain point PyTorch is having with its incompatible `nonzero` definition.
This may be worth considering. The name is a little annoying though, especially because `argwhere` isn't a counterpart to `where`, but the same thing as the discouraged one-element form of `where`.
>>> x = np.array([0, 4, 7, 0])
>>> np.nonzero(x)
(array([1, 2]),)
>>> np.where(x)
(array([1, 2]),)
>>> np.argwhere(x)
array([[1],
[2]])
>>> x[np.nonzero(x)]
array([4, 7])
>>> x[np.where(x)]
array([4, 7])
>>> x[np.argwhere(x)] # this needs a squeeze() to roundtrip for 1-D input
array([[4],
[7]])
end of old comment
`argwhere` returns (from the docs): `index_array`: `(N, a.ndim)` ndarray
Let me add a higher-dimensional example as well:
>>> x.shape
(4, 3, 2)
>>> x[np.nonzero(x)]
array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24])
>>> x[np.argwhere(x)].shape
(24, 3, 3, 2)
There are more complex ways to use the indices returned from `argwhere`, but I don't see a straightforward one (I may be missing it, still sleepy). So how useful is `argwhere` for doing an "unknown-shape `nonzero`"?
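For reference, one way to round-trip with `argwhere` on N-D input is to transpose the index array and pass it as a tuple, which reproduces the `nonzero` indexing. The sketch below assumes `x` is `np.arange(1, 25).reshape(4, 3, 2)`; the thread does not show how `x` was built, so that definition is a guess that matches the shape and values printed above.

```python
import numpy as np

# Assumed stand-in for x: shape (4, 3, 2), all 24 values nonzero.
x = np.arange(1, 25).reshape(4, 3, 2)

idx = np.argwhere(x)      # shape (24, 3): one row of coordinates per nonzero element
# Transposing gives 3 rows of length 24; as a tuple they index one axis each.
values = x[tuple(idx.T)]  # shape (24,)
```

This is the same split-by-dimension step that turns `argwhere` output into `nonzero` output, just written inline.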
> However, for libraries that have arrays with unknown shape elements (like Dask), `nonzero` can run into some difficulties, as the different arrays may not be concatenated (at least not without an explicit override). `argwhere` avoids this issue: it already returns a single array, so there is no need to `concatenate`, and the array can be split apart easily to satisfy the `nonzero` case (for example: dask/dask#2539). As a result, `argwhere` becomes practically more useful for handling these unknown-shape cases.
Can you give a concrete example of where `argwhere` lets you do something in Dask that `nonzero` does not?
It's trivial to implement `nonzero` from `argwhere`, since it is just splitting out the results. In fact, all of the `*nonzero`-style functions in Dask are implemented with `argwhere`. Going in the opposite direction unfortunately requires a bit of a leap of faith, as one needs to call `concatenate` and tell it to skip checking that the shapes match (in Dask this would be `allow_unknown_chunksizes=True`). Of course, for `nonzero` this is fine, as we know the shapes should match (though that's not generally the case, hence the option). Where it gets trickier is that telling `concatenate` to allow this behavior may mean handling Dask Arrays in a non-spec-conforming way (the option does not exist for other array implementations), which could make it hard to write Array API code that accommodates Dask in these cases. This can also come up in other places where code depends on shape checking.
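As a rough illustration of the opposite direction, the sketch below rebuilds `argwhere`-style output from `nonzero` using plain NumPy. The helper name `argwhere_from_nonzero` is hypothetical. NumPy knows every length up front, so the `concatenate` call here is unproblematic; with lazily evaluated arrays whose lengths are unknown, this joining step is where the shape check described above would come into play.

```python
import numpy as np

def argwhere_from_nonzero(x):
    """Hypothetical helper: rebuild argwhere-style output from nonzero."""
    per_dim = np.nonzero(x)  # tuple of x.ndim arrays, each of length N
    # Turn each 1-D index array into a column of shape (N, 1).
    columns = [i[:, np.newaxis] for i in per_dim]
    # Joining the columns requires that all of them have the same length N.
    return np.concatenate(columns, axis=1)  # shape (N, x.ndim)

x = np.array([[0, 4], [7, 0]])
rebuilt = argwhere_from_nonzero(x)
```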
Right, I guess my question is why getting all indices together into a single array matters.
As this proposal is without a champion, I'll go ahead and close.