
`diff` to allow Python scalar or 0d array `append` and `prepend`?

Open mdhaber opened this issue 9 months ago • 3 comments

The specification of `diff` requires `append` and `prepend` to be arrays with the same shape as the first argument, except along `axis`.

array-api-strict currently does not enforce this:

```python
import array_api_strict as xp
xp.diff(xp.ones((2, 3)), append=xp.asarray(10), axis=0)
# Array([[0., 0., 0.],
#        [9., 9., 9.]], dtype=array_api_strict.float64)
xp.diff(xp.ones((2, 3)), append=xp.asarray(10), axis=1)
# Array([[0., 0., 9.],
#        [0., 0., 9.]], dtype=array_api_strict.float64)
```

Like NumPy, CuPy, and JAX, it broadcasts the value as necessary and follows the same type promotion rules as other functions. (I see that following the usual promotion rules was already tabled, so I won't bring that up again right now.)
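For comparison, NumPy's `diff` documentation states that scalar `prepend`/`append` values are expanded to length 1 along `axis` and to the input's shape along all other axes. A quick check, assuming NumPy is installed, that a Python scalar, a 0-d array, and an explicitly shaped array all give the same result as the array-api-strict output for `axis=0`:

```python
import numpy as np

x = np.ones((2, 3))

# Python scalar, 0-d array, and an explicit (1, 3) array appended along axis 0
from_scalar = np.diff(x, append=10, axis=0)
from_0d = np.diff(x, append=np.asarray(10), axis=0)
from_full = np.diff(x, append=np.full((1, 3), 10.0), axis=0)

print(from_scalar)
```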

This suggests that the operation can be well-defined, and I think it can be useful (e.g., `prepend=0` and `append=x.shape[axis]`). It would be even more useful to accept Python scalars, following the example of `clip` (which also requires array `min`/`max` to have the same type as `x`, but allows Python scalars). This would avoid requiring the user to do something like:

```python
# size 1 along (non-negative) `axis`, matching `x` elsewhere
append = xp.full(x.shape[:axis] + (1,) + x.shape[axis+1:], append, dtype=xp.result_type(x, prepend, append))
```

mdhaber, Apr 18 '25 16:04

As noted in https://github.com/data-apis/array-api/pull/791, scalar values for these keyword arguments are not supported in PyTorch, which was a primary motivation for restricting portable behavior to `prepend` and `append` being arrays.

kgryte, Apr 18 '25 20:04

Also ref: https://github.com/data-apis/array-api/issues/784

kgryte, Apr 18 '25 20:04

Thanks for those references; good to see it was considered.

I understand that when a library doesn't support a feature that the rest do, the standard needs to strike a balance between the importance of the feature and the difficulty of adding support. Features like uint16-uint64 support, a `size` attribute, negative step sizes for slices, etc. are very challenging for PyTorch to add and impossible for array-api-compat to compensate for within its current scope, but the standard includes them because they are very important. (Thank goodness!)

This is a lot less important than those features, but it is also much easier to add. https://github.com/data-apis/array-api-compat/issues/271 and https://github.com/data-apis/array-api-compat/issues/274 list several other cases of PyTorch lacking support for scalars (and these are just the ones I've run across), yet the standard decided to include scalar support there. I'd suggest treating this case the same way.

mdhaber, Apr 18 '25 20:04