array-api
feat: add `bincount` to the specification
This PR

- resolves https://github.com/data-apis/array-api/issues/812 by adding `bincount` to the specification for counting the number of occurrences of each element in an input integer array.
- based on comparison data, only supports a `minlength` keyword argument.
- allows `weights` to be both a positional and keyword argument.
- allows `weights` to have any numeric data type, including complex.
- specifies that, when `weights` is not provided, the output data type must be an integer data type. The data type rules follow other statistical functions (e.g., `sum`), where a minimum precision is required.
- specifies that, when `weights` is provided, the output array must have the same data type as `weights`.
- specifies that an input array `x` should (not must) be a one-dimensional array. TensorFlow allows `x` to be multi-dimensional, and this PR chooses to provide wiggle room to allow supporting multi-dimensional arrays. One reason this isn't commonly supported in other libraries is that `bincount` has a data-dependent output shape; however, TF supports kwargs which allow specifying a static output shape, thus allowing `bincount` to generalize to multiple dimensions.
- includes an admonition that `bincount` has a data-dependent output shape and thus certain libraries are allowed to omit this function if too difficult to implement. This follows similar practice for other APIs having data-dependent output shapes (e.g., `unique*`).
- specifies that, if `x` contains negative values, behavior is unspecified and thus implementation-defined. NumPy raises an exception, while JAX clips.
- specifies that `weights` have the same shape as `x`; although, IMO, this restriction is not necessary and could be relaxed to broadcast-compatibility.
- specifies that the returned array should (not must) have shape `(N,)`, where `N = max(xp.max(x)+1, minlength)`. The use of should is intentional, in order to allow libraries such as JAX and TF to support other keyword arguments which may constrain the output shape.
- specifies that the default value of `minlength` must be `0`. According to docs, both CuPy and TF use a default of `None`.
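To make the proposed one-dimensional semantics concrete, here is a minimal pure-Python sketch. It is not the specification text or any library's implementation: `bincount_ref` and its `on_negative` switch are hypothetical names, Python scalar types stand in for dtypes, and the `"raise"`/`"clip"` modes only mimic the NumPy and JAX behaviors for negative inputs described above.

```python
from typing import Optional, Sequence, Union

Number = Union[int, float, complex]


def bincount_ref(
    x: Sequence[int],
    weights: Optional[Sequence[Number]] = None,
    minlength: int = 0,
    on_negative: str = "raise",
) -> list:
    """Hypothetical reference for the proposed 1-D bincount semantics."""
    if minlength < 0:
        raise ValueError("minlength must be non-negative")
    if weights is not None and len(weights) != len(x):
        # The proposal requires `weights` to have the same shape as `x`.
        raise ValueError("weights must have the same shape as x")
    if any(v < 0 for v in x):
        # Behavior for negative values is implementation-defined.
        if on_negative == "raise":  # NumPy-like: reject negative input
            raise ValueError("x must not contain negative values")
        if on_negative != "clip":
            raise ValueError("on_negative must be 'raise' or 'clip'")
        x = [max(v, 0) for v in x]  # JAX-like: clip negatives to 0
    # N = max(max(x) + 1, minlength); an empty x gives N = minlength.
    n = max(max(x) + 1 if len(x) > 0 else 0, minlength)
    if weights is None:
        out = [0] * n  # counts are integers when no weights are given
        for v in x:
            out[v] += 1
        return out
    # With weights, each bin holds a sum of weights, starting from a zero of
    # the weights' type (standing in for "same data type as weights").
    zero = type(weights[0])(0) if len(weights) > 0 else 0
    out = [zero] * n
    for v, w in zip(x, weights):
        out[v] += w
    return out
```

For example, `bincount_ref([0, 1, 1, 3])` gives `[1, 2, 0, 1]`, and passing `minlength=6` pads it to `[1, 2, 0, 1, 0, 0]`; a `minlength` smaller than `max(x) + 1` has no effect.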
Questions
- Are we okay with allowing `weights` to be complex? This does not appear to be supported in NumPy (ref: https://github.com/numpy/numpy/issues/23313 and https://github.com/numpy/numpy/issues/16903); however, there isn't a technical reason why weights cannot be complex, as summation is well-defined for complex numbers.
- Similar to `sum` and other statistical functions, should we support an output `dtype` kwarg (e.g., in order to support overriding the default integer output type behavior)?
- Are we okay with `weights` being both positional and a keyword argument?
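On the first question, a small sketch suggests why complex weights pose no technical problem: each output bin is just a sum of the weights that land in it, and complex addition is well-defined. `weighted_bincount` is a hypothetical pure-Python helper, not NumPy's or any other library's API.

```python
def weighted_bincount(x, weights, minlength=0):
    """Hypothetical helper: sum the weights that fall into each integer bin."""
    n = max(max(x) + 1 if x else 0, minlength)
    # Start every bin from a zero of the weights' type (0j for complex weights).
    zero = 0 * weights[0] if weights else 0
    out = [zero] * n
    for v, w in zip(x, weights):
        out[v] += w
    return out


x = [0, 2, 2]
w = [1 + 1j, 2 - 1j, 0.5j]
print(weighted_bincount(x, w))  # [(1+1j), 0j, (2-0.5j)]
```

Summing the output recovers the sum of all weights (here `3+0.5j`), the same invariant that holds for real-valued weights.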
The shape of the output array for this function depends on the data values in `x`; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) can find this function difficult to implement without knowing the values in `x`. Accordingly, such libraries may choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details.
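The difficulty for graph-building libraries shows up with two inputs of identical shape: the output length is a function of the values, so it cannot be inferred from the input shape alone. The helper below is a hypothetical illustration of the proposed `N = max(max(x) + 1, minlength)` rule, not part of the specification.

```python
def bincount_output_length(x, minlength=0):
    """Output length under the proposed rule N = max(max(x) + 1, minlength)."""
    return max(max(x) + 1 if x else 0, minlength)


a = [0, 1, 2, 3]
b = [0, 1, 2, 99]

# Both inputs have shape (4,), but the outputs would have different shapes.
print(len(a), len(b))                                        # 4 4
print(bincount_output_length(a), bincount_output_length(b))  # 4 100
```

A tracer that only knows that `x` has shape `(4,)` cannot choose between these two results, which is why a caller-supplied static output length sidesteps the issue.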
Why not allow JAX to implement this function by adding an optional `length` argument, and require every library to provide this function in the Array API whenever `length` is provided? Most algorithms are amenable to that. Otherwise, you'd have to write JAX versions of the same algorithm (yuck).
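A rough sketch of the `length` idea: if the caller fixes the number of bins up front, the output shape no longer depends on the data. `bincount_static` is hypothetical, and dropping out-of-range values (negative or `>= length`) is an assumption made here for simplicity; real libraries may handle such values differently.

```python
def bincount_static(x, length, weights=None):
    """Hypothetical bincount whose output always has exactly `length` bins."""
    if weights is None:
        weights = [1] * len(x)
    out = [0] * length  # shape is known before any value of x is inspected
    for v, w in zip(x, weights):
        # Assumption: values outside [0, length) are silently ignored.
        if 0 <= v < length:
            out[v] += w
    return out


print(bincount_static([0, 1, 1, 5], length=3))  # [1, 2, 0]
```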
- Both JAX and TensorFlow allow `x` to be multi-dimensional
I don't think this is correct in the case of JAX. I confirmed that this errors in v0.5.x and v0.6.x:
```python
import jax
import jax.numpy as jnp
print(jax.__version__)
jnp.bincount(jnp.arange(9).reshape(3, 3))  # raises ValueError for 2-D input
```

```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-1-1852179904> in <cell line: 0>()
2 import jax.numpy as jnp
3 print(jax.__version__)
----> 4 jnp.bincount(jnp.arange(9).reshape(3, 3))
/usr/local/lib/python3.11/dist-packages/jax/_src/numpy/lax_numpy.py in bincount(x, weights, minlength, length)
2994 raise TypeError(f"x argument to bincount must have an integer type; got {_dtype(x)}")
2995 if np.ndim(x) != 1:
-> 2996 raise ValueError("only 1-dimensional input supported.")
2997 minlength = core.concrete_or_error(operator.index, minlength,
2998 "The error occurred because of argument 'minlength' of jnp.bincount.")
ValueError: only 1-dimensional input supported.
```
specifies that an input array x should (not must) be a one-dimensional array. Both JAX and TensorFlow allow x to be multi-dimensional, and this PR chooses to provide wiggle room to allow supporting multi-dimensional arrays. One reason this isn't commonly supported in other libraries is that bincount has a data-dependent output shape; however, both JAX and TF support kwargs which allow specifying a static output shape, thus allowing bincount to generalize to multiple dimensions.
It may be nice to treat additional dimensions as broadcasted (batch) dimensions, as in, e.g., `matrix_transpose`. That is, if `x` has shape `(*xs, xn)` and you want to return `length` bins, you could return an array having shape `(*xs, length)`. This is just the broadcasted generalization of the one-dimensional case.
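The broadcasted generalization could look roughly like this pure-Python sketch, in which every leading dimension is treated as a batch dimension and counting happens along the last axis only. `batched_bincount` is hypothetical; it takes a fixed `length` so that all rows share one output shape, and it ignores values outside `[0, length)` by assumption.

```python
def batched_bincount(x, length):
    """Hypothetical bincount over the last axis of a nested list.

    `x` has shape (*xs, xn); the result has shape (*xs, length).
    Note: an empty list is treated as an empty 1-D array.
    """
    if len(x) > 0 and isinstance(x[0], list):
        # Leading (batch) dimension: recurse into each sub-array.
        return [batched_bincount(sub, length) for sub in x]
    out = [0] * length
    for v in x:
        # Assumption: values outside [0, length) are ignored.
        if 0 <= v < length:
            out[v] += 1
    return out


x = [[0, 1, 1], [2, 2, 0]]  # shape (2, 3)
print(batched_bincount(x, length=3))  # [[1, 2, 0], [1, 0, 2]], shape (2, 3)
```

With `length=4`, the same input would yield an output of shape `(2, 4)`; the batch shape `(2,)` is carried through unchanged.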
Thanks @kgryte for the detailed proposal. This function is heavily used and present everywhere, so it makes sense to add from that perspective. The main question I have at the moment is whether there is a good alternative for `bincount` that doesn't suffer from the value-dependent issue.
The function itself is pretty specific: for it to work, you have to shift the values to a non-negative range starting at zero. I think that's usually not done; it's more common to use something like `histogram` or `scipy.stats.binned_statistic` in those cases. `bincount` is usually used for distributions of integers that are already in the `[0, N)` range with `N` not very large (otherwise the output size explodes).
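The shifting step mentioned above can be sketched in pure Python as follows. `shifted_bincount` is a hypothetical helper, not a proposed API; it returns the offset alongside the counts so that bin `i` corresponds to the value `offset + i`. The output length is `max(x) - min(x) + 1`, so it grows with the value range rather than with `len(x)`, which is the size blow-up the comment warns about.

```python
def shifted_bincount(x):
    """Hypothetical helper: count integers after shifting them to start at 0.

    Returns (offset, counts), where counts[i] is the number of occurrences
    of the value offset + i.
    """
    if not x:
        return 0, []
    offset = min(x)
    counts = [0] * (max(x) - offset + 1)
    for v in x:
        counts[v - offset] += 1
    return offset, counts


print(shifted_bincount([-2, -2, 0, 1]))          # (-2, [2, 0, 1, 1])
print(len(shifted_bincount([0, 100_000])[1]))    # 100001 bins for 2 values
```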
Are we okay with allowing `weights` to be complex?
No. This seems super niche, and it's not supported by NumPy - so no reason to even consider this I'd think.
Are we okay with `weights` being both positional and a keyword argument?
I'd vote for keyword-only, since it's a very descriptive name and there's no real reason to use positional-only as far as I can tell.
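For illustration, a keyword-only `weights` would look like the following in Python. This is a hypothetical pure-Python stand-in, not the text of the specification; making `x` positional-only follows the `x, /` convention the standard commonly uses for array arguments (e.g., in `sum`).

```python
def bincount(x, /, *, weights=None, minlength=0):
    """Hypothetical signature: `x` positional-only; `weights`, `minlength` keyword-only."""
    n = max(max(x) + 1 if x else 0, minlength)
    out = [0] * n
    for i, v in enumerate(x):
        out[v] += 1 if weights is None else weights[i]
    return out


print(bincount([0, 1, 1], weights=[0.5, 1.0, 2.0]))  # [0.5, 3.0]

try:
    bincount([0, 1, 1], [0.5, 1.0, 2.0])  # positional weights are rejected
except TypeError as exc:
    print("TypeError:", exc)
```

The keyword-only form makes call sites self-describing and rules out accidentally passing `minlength` where `weights` was intended.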
JAX supporting 1-dimensional arrays.
@jakevdp It would be good to update the docstring for `bincount` then, as it currently suggests that N-dimensional support is present. It is also not clear why JAX's docs state that the array must consist of positive integers, rather than nonnegative integers.
I'd vote for keyword-only, since it's a very descriptive name and there's no real reason to use positional-only as far as I can tell.
@rgommers I am fine making the change to kwarg-only for guaranteed portability. sklearn includes both positional and kwarg usage, with the latter being more predominant. Similarly, from a search on sourcegraph, kwarg usage is more common, although positional usage of np.bincount is not uncommon.
@jakevdp Would be good to update the docstring then for `bincount`
Thanks for pointing that out – updated in https://github.com/jax-ml/jax/pull/29441.
Slight preference for weights as keyword only. It is easy enough to update in scikit-learn.
As a user, it is annoying/tedious if different libraries require different treatment. The whole point of the array API is to have something uniform instead of maintaining a big bunch of `if` statements. From that point of view, it would be nice to have something that works for JAX as well. This would mean making `length` an argument mentioned in the standard. Is it possible to make a generic recommendation for what to pass as its value? My first reaction to "you have to provide `length`" was "how would I know what it should be? Can't you work it out for me far better than I can?" But maybe `max(a) + 1` and `len(a)` cover the vast majority of cases for naive users and get people started, and then they can ponder whether there is a better value? If it is that easy, and it would remove the need to special-case libraries like JAX, that might be a tradeoff worth making? Or am I missing something?
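A quick way to see the trade-off between candidate `length` values, assuming a fixed-length variant that silently ignores values `>= length` (an assumption; no such variant is defined in the proposal): `len(a)` can be too small, while `max(a) + 1` never loses counts but must itself be computed from the data, which is exactly what a tracing library cannot do statically. `dropped_count` is a hypothetical helper.

```python
def dropped_count(x, length):
    """How many elements a fixed-`length` bincount would ignore (values >= length)."""
    return sum(1 for v in x if v >= length)


a = [0, 3, 3, 7]
print(dropped_count(a, len(a)))      # 1  (value 7 does not fit in 4 bins)
print(dropped_count(a, max(a) + 1))  # 0  (max(a) + 1 bins cover every value)
```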
My takeaway from the discussion in the community meeting was similar to the question @betatim asked above. Can actually be split into two:
- Should we add a `length` keyword? It seems to be the most common use case, and seems to make more sense than `minlength` (the way `minlength` is often used is "give me length `N`", rather than ">= `N`").
- Can we leave out `minlength`? I.e., is it actually useful separately from `length`?
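The semantic difference behind these two questions can be sketched as follows: `minlength` is only a lower bound, so the output can still grow with the data, whereas `length` fixes the output size exactly. The function `counts`, and the rule that `length` takes precedence when both are given, are hypothetical choices for this illustration.

```python
def counts(x, *, minlength=0, length=None):
    """Hypothetical bincount contrasting `minlength` with an exact `length`."""
    if length is None:
        n = max(max(x) + 1 if x else 0, minlength)  # depends on the values in x
    else:
        n = length  # known before looking at x (takes precedence here)
    out = [0] * n
    for v in x:
        # Values outside [0, n) are ignored; with `minlength` and a
        # non-negative x, no value is ever outside that range.
        if 0 <= v < n:
            out[v] += 1
    return out


x = [0, 1, 5]
print(counts(x, minlength=3))  # [1, 1, 0, 0, 0, 1]  (grows past 3)
print(counts(x, length=3))     # [1, 1, 0]           (exactly 3; value 5 dropped)
```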