array-api

RFC: add `isin` for elementwise set inclusion test

Open lucascolley opened this issue 1 year ago • 5 comments

Prior art

  • NumPy - https://numpy.org/doc/stable/reference/generated/numpy.isin.html
  • CuPy - https://docs.cupy.dev/en/latest/reference/generated/cupy.isin.html
  • Dask - https://docs.dask.org/en/stable/generated/dask.array.isin.html
  • JAX - https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.isin.html
  • PyTorch - https://pytorch.org/docs/stable/generated/torch.isin.html
  • ndonnx - https://ndonnx.readthedocs.io/en/latest/api/ndonnx.additional.html#ndonnx.additional.isin
  • MLX - isin is not present yet

Motivation

This function is used in scikit-learn. They've implemented it in terms of the standard, and that implementation could find a home in array-api-extra: https://github.com/data-apis/array-api-extra/issues/34. @asmeurer suggested there that we should also consider adding this to the standard.

lucascolley • Nov 21 '24 23:11

Another potential reason for adding it is that an efficient implementation is nontrivial and relies on heuristics that depend on the input sizes.
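To make the point about size-dependent heuristics concrete, here is a minimal NumPy sketch (not NumPy's or scikit-learn's actual code) that switches between two strategies: a loop of `==` comparisons when `test_elements` is small relative to the input, and a sort-based merge otherwise. The name `isin_sketch`, the cut-off constants, and the overall structure are assumptions modelled loosely on NumPy's approach.

```python
import numpy as np

# Assumed cut-off, modelled loosely on the heuristic NumPy uses internally;
# treat the exact constants as an assumption.
_SMALL_TEST_FACTOR = 10
_SIZE_EXPONENT = 0.145


def isin_sketch(x, test_elements, *, assume_unique=False, invert=False):
    """Hypothetical elementwise membership test of x in test_elements."""
    x = np.asarray(x)
    flat_x = x.ravel()
    # Like NumPy, accept any shape for test_elements and flatten it to 1-D.
    flat_t = np.asarray(test_elements).ravel()

    if flat_t.size < _SMALL_TEST_FACTOR * flat_x.size ** _SIZE_EXPONENT:
        # Strategy 1: one == comparison per test element, OR-ed together.
        mask = np.zeros(flat_x.shape, dtype=bool)
        for t in flat_t:
            mask |= flat_x == t
    else:
        # Strategy 2: sort-based. Deduplicate both inputs so each value
        # appears at most twice in the concatenation: once from x and once
        # from test_elements (if present there).
        if assume_unique:
            ux, inv = flat_x, None
        else:
            ux, inv = np.unique(flat_x, return_inverse=True)
        ut = flat_t if assume_unique else np.unique(flat_t)
        both = np.concatenate((ux, ut))
        # A stable sort keeps the copy from x before the copy from test_elements.
        order = np.argsort(both, kind="stable")
        s = both[order]
        # A value from x is present iff its sorted successor is equal to it.
        hit = np.concatenate((s[1:] == s[:-1], [False]))
        # Scatter the flags back to the unsorted positions.
        found = np.empty(both.shape, dtype=bool)
        found[order] = hit
        found = found[: ux.size]
        # Map unique values back to the original (possibly repeated) elements.
        mask = found if inv is None else found[inv]

    if invert:
        mask = ~mask
    return mask.reshape(x.shape)
```

The loop costs roughly one pass over `x` per test element, while the sort-based path costs a sort of both inputs; which one wins depends on the relative sizes, which is exactly the kind of tuning that is easier to get right once inside a library than in every downstream project.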

asmeurer • Nov 22 '24 00:11

Thanks @lucascolley. I've added ndonnx (which has it) and MLX (which doesn't) to the issue description.

This seems like a very reasonable proposal to me. Implementing isin in terms of other primitives in the standard is a little complex indeed.

The return type should always be a boolean array. The NumPy docs say it can be a bool for a single input element, but that's actually a bug in the docs. I checked NumPy, JAX, and PyTorch, and all three return a 0-D array.
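A quick NumPy check of the 0-D behaviour described above (only NumPy is exercised here, not JAX or PyTorch):

```python
import numpy as np

# A scalar first argument still produces a 0-D boolean ndarray, not a Python bool.
result = np.isin(3, [1, 2, 3])
print(type(result).__name__, result.shape, result.dtype)  # ndarray () bool
```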

rgommers • Nov 22 '24 09:11

I think the main thing to discuss here is which keywords are desired. NumPy and Dask use:

def isin(element, test_elements, assume_unique=False, invert=False, *, kind=None)

The private scikit-learn implementation here is:

def isin(element, test_elements, xp, assume_unique=False, invert=False)

JAX:

def isin(element, test_elements, assume_unique=False, invert=False, *, method='auto')

PyTorch:

def isin(elements, test_elements, *, assume_unique=False, invert=False)

ndonnx:

def isin(x: Array, items: Sequence[Scalar]) -> Array

The assume_unique and invert keywords seem useful and easy to support, so this signature should probably work:

def isin(x: Array, test_elements: Array, /, *, assume_unique: bool = False, invert: bool = False) -> Array[bool]

The type of test_elements is still somewhat TBD; perhaps it could be a union of arrays and sequences?
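As a quick illustration of what `/` and `*` mean in the proposed signature, here is a hypothetical NumPy-backed wrapper (not part of any library) that simply forwards to `numpy.isin`:

```python
import numpy as np


def isin(x, test_elements, /, *, assume_unique=False, invert=False):
    # x and test_elements are positional-only (they come before `/`);
    # assume_unique and invert are keyword-only (they come after `*`).
    return np.isin(x, test_elements, assume_unique=assume_unique, invert=invert)


x = np.array([1, 2, 3, 4])
print(isin(x, np.array([2, 4]), invert=True).tolist())  # [True, False, True, False]
```

With this shape, calls such as `isin(x, t, True)` or `isin(x=x, test_elements=t)` raise `TypeError`, which leaves implementations free to rename the first two parameters and keeps the boolean flags self-documenting at call sites.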

rgommers • Nov 28 '24 16:11

We discussed this in the community meeting today. A summary with a couple of points to follow up on:

  • In general, folks were 👍🏼 on adding isin
  • For the second argument, accepting arrays and scalars but not sequences seemed preferred.
  • Some discussion about promotion behavior, which needs to be specified. Since isin is semantically similar to an element-wise comparison, it is probably a good idea to match what == does.
  • For the second argument: limit to 1-D, or accept any shape and then reshape to 1-D? NumPy does the latter.
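Both open points can be illustrated with a brute-force NumPy reference (the name `isin_reference` is hypothetical; this is a sketch, not a proposed implementation). It flattens the second argument to 1-D, as NumPy does, and is built directly on `==`, so it inherits the promotion rules of `==`:

```python
import numpy as np


def isin_reference(x, test_elements, *, invert=False):
    x = np.asarray(x)
    # Accept any shape for test_elements and flatten it to 1-D (NumPy's choice).
    t = np.reshape(np.asarray(test_elements), (-1,))
    # Broadcast x against every test element with ==, then reduce over that axis.
    mask = np.any(x[..., np.newaxis] == t, axis=-1)
    return ~mask if invert else mask


x = np.array([1, 2, 3])
t = np.array([[1.0, 5.0], [3.0, 7.0]])  # 2-D and float, while x is integer
print(isin_reference(x, t).tolist())  # [True, False, True]
print(np.isin(x, t).tolist())         # [True, False, True]
```

This reference uses O(x.size * test_elements.size) memory for the broadcast comparison, which is why practical implementations switch to sorting or other strategies for large inputs.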

rgommers • Nov 28 '24 21:11

A PR adding isin to the specification is now up for review: https://github.com/data-apis/array-api/pull/959

kgryte • Jun 12 '25 10:06