array-api-compat

ENH: use torch.clamp for wrapped_torch.clip

Open ev-br opened this issue 3 months ago • 1 comments

Closes gh-350

Otherwise, the version which emulates "clip" fails with torch.vmap.


This patch is less innocuous than it looks, because it changes the promotion rules. Previously, min and max were not upcasting the result:

In [9]: xpn = array_namespace(np.arange(3))

In [10]: xpn.clip(xpn.arange(5, dtype=xpn.int8), 2, xpn.asarray(3.0))
Out[10]: array([2, 2, 2, 3, 3], dtype=int8)

and now they are:

In [11]: xpt = array_namespace(torch.ones(3))

In [12]: xpt.clip(xpt.arange(5, dtype=xpt.int8), 2, xpt.asarray(3.0))
Out[12]: tensor([2., 2., 2., 3., 3.])

OTOH, the wording in the spec (https://data-apis.org/array-api/draft/API_specification/generated/array_api.clip.html) is:

min ... should have the same data type as x.

and

If either min or max is an array having a different data type than x, behavior is unspecified and thus implementation-dependent.
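
For illustration, a caller who wants to stay out of the unspecified case could cast array-valued bounds to the dtype of x before calling clip. This is a minimal sketch with plain NumPy standing in for the namespace; `clip_same_dtype` is a hypothetical helper, not part of array-api-compat or the standard.

```python
import numpy as np


def clip_same_dtype(xp, x, lo=None, hi=None):
    """Hypothetical helper: clip `x` after casting array bounds to x.dtype."""

    def _cast(bound):
        # None and Python scalars are passed through unchanged; only array
        # bounds can have a dtype that differs from x.dtype.
        if bound is None or isinstance(bound, (bool, int, float)):
            return bound
        return xp.asarray(bound, dtype=x.dtype)

    return xp.clip(x, _cast(lo), _cast(hi))


x = np.arange(5, dtype=np.int8)
out = clip_same_dtype(np, x, 2, np.asarray(3.0))
print(out, out.dtype)  # the result keeps the int8 dtype of x
```

Note that the cast is lossy when a floating-point bound is not integral (3.5 becomes 3 for an int8 x), so this only makes sense when the caller really wants the result in the dtype of x.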

ev-br avatar Sep 08 '25 18:09 ev-br

cc/fyi @mdhaber : IIRC scipy.stats is a big user of the [()] empty tuple indexing. Apparently, it is problematic for torch.vmap, especially if used on the left-hand side of an = assignment. This PR and the issue it closes were triggered by this line: https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/common/_aliases.py#L424

ev-br avatar Nov 25 '25 15:11 ev-br

Yes, but I don't recall putting [()] on the left of an equals sign. Typically, when code in scipy.stats mutates a 0-d array, it's a special case of mutating a multidimensional array. And usually, that's done with a logical mask generated by some sort of comparison operator, so the 0-d case ends up indexing with a 0-d boolean array, not an empty tuple.

Almost every scipy.stats function uses [()] to the right of an equals sign / return statement to follow the precedent set by NumPy: return scalars instead of 0-d arrays at almost every opportunity. Now that we have to consider array-api-strict, empty tuple __getitem__ is only ever used on 0-d arrays, e.g. like x[()] if x.ndim == 0 else x. Fine with me to get rid of all these uses when NumPy makes it less difficult to follow the standard (mdhaber/numpy#2)! But hopefully getting with [()] is not problematic in the meantime? Having to special case that for 0-d arrays is already verbose enough!
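
As a rough sketch of the read-side pattern described above, with plain NumPy: `unwrap_0d` is a hypothetical name used only for illustration, not an actual scipy.stats function.

```python
import numpy as np


def unwrap_0d(x):
    """Hypothetical helper: scalar for 0-d input, the array itself otherwise."""
    # Empty-tuple indexing is applied only to 0-d arrays, and only for reading.
    return x[()] if x.ndim == 0 else x


scalar = unwrap_0d(np.asarray(3.0))         # NumPy scalar rather than a 0-d array
array = unwrap_0d(np.asarray([1.0, 2.0]))   # 1-D input is returned unchanged
print(type(scalar).__name__, type(array).__name__)
```

Here [()] appears only in a getting position, never on the left-hand side of an assignment.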

mdhaber avatar Dec 01 '25 05:12 mdhaber