array-api-compat
`sign` special case implementations
According to v2022.12 (and v2023.12) of the array API standard, the special cases of sign include:
> For real-valued operands... If `x_i` is NaN, the result is NaN.
However,
```python
from array_api_compat import numpy as np, cupy as cp, torch

np.sign(np.asarray(np.nan))  # nan
cp.sign(cp.asarray(cp.nan))  # array( 0.000e+00)
torch.sign(torch.asarray(torch.nan))  # tensor(0.)
```
There may be other special cases that are not yet implemented. I haven't done a complete review, but I noticed that torch gives an error when the input is complex.
```python
torch.sign(torch.asarray(1+1j))
# RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.
```
I haven't tried to add any special-case workarounds here yet. If this is causing issues for you, we can add one.
We definitely should add complex support to torch.sign. It sounds like that will just be an easy wrapper of torch.sgn.
Yeah, I actually did run into this special case while converting code from NumPy to the array API. The code produces a status flag that depends on whether an element is positive, negative, zero, or NaN, so no wonder I ran into it. So it would be helpful to add the special case, but support for 2023.12 would be a higher priority for me.
We can add it. The main concern with adding special-case handling is it means adding a mask to the function, so it could be a minor hit to performance. I would at least make sure there are upstream issues filed with the libraries that don't implement it correctly.
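A minimal sketch of the kind of mask-based workaround being discussed, assuming an array API compatible namespace `xp` whose own `sign` may return 0 for NaN. The helper name `sign_nan_propagating` is hypothetical; it is not the compat layer's API or the fix in any PR. NumPy is used as the default namespace only so the sketch runs anywhere.

```python
import numpy as np


def sign_nan_propagating(x, xp=np):
    """Hypothetical wrapper: NaN-propagating sign for real floating-point input.

    `xp` is assumed to be an array API compatible namespace; its own `sign`
    may map NaN to 0 (as reported above for CuPy and PyTorch).
    """
    s = xp.sign(x)
    # The extra isnan pass plus the `where` select is the "mask" whose
    # performance cost is the concern raised above.
    return xp.where(xp.isnan(x), x, s)


x = np.asarray([-2.0, 0.0, 3.0, np.nan])
print(sign_nan_propagating(x))
```

With NumPy this is a no-op in terms of results, since `np.sign` already returns NaN; the extra work only pays off for namespaces that don't.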
> The main concern with adding special-case handling is it means adding a mask to the function, so it could be a minor hit to performance.
I think that won't be a minor hit. For element-wise functions, adding isnan checks and masking may slow things down by 2x or more.
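One way to check the size of that hit on a given machine is a rough micro-benchmark like the sketch below. It uses NumPy only, and the helper names are hypothetical; the ratio it prints will depend on array size, dtype, and memory bandwidth, so no particular result is claimed here.

```python
import timeit

import numpy as np


def plain_sign(x):
    return np.sign(x)


def masked_sign(x):
    # Same result as np.sign for NumPy input, but with an extra isnan pass
    # and a select, mimicking the special-case masking under discussion.
    return np.where(np.isnan(x), x, np.sign(x))


def time_per_call(func, x, number=20):
    """Average wall-clock seconds per call of func(x) over `number` runs."""
    return timeit.timeit(lambda: func(x), number=number) / number


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal(1_000_000)
    t_plain = time_per_call(plain_sign, x)
    t_masked = time_per_call(masked_sign, x)
    print(f"plain:  {t_plain * 1e3:.3f} ms")
    print(f"masked: {t_masked * 1e3:.3f} ms")
    print(f"ratio:  {t_masked / t_plain:.2f}x")
```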
In general, these special cases have not been well validated yet, so I'd be quite reluctant to assume they are all correctly specified or support them in the compat layer.
In this case, it may be possible to fix this in NumPy. There's a bunch of special-casing to make sign(nan) return nan, which looks like old code and deliberately disagrees with C99's signbit:
https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/_core/src/umath/loops.c.src#L1341-L1350
https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/_core/src/umath/loops.c.src#L2262-L2271
Changing it would be a minor backwards-compatibility break, but probably an improvement. Much larger changes were made to np.sign recently in gh-25441 (https://github.com/numpy/numpy/pull/25441); however, sign(nan) = nan was not discussed there. It does seem like sign deviates from signbit for no good reason; it was probably an ad-hoc pre-C99 decision.
NumPy is the one that agrees with what the array API says. If we think sign(nan) should not return NaN, then the array API should change.
> If we think sign(nan) should not return NaN, then the array API should change.
I'm honestly not sure. These special cases are a pain: it takes a lot of time to figure out what was discussed before and why different libraries end up with different return values. I suspect the following:
- PyTorch and CuPy here simply use the C/C++ version of `signbit`. Special cases are defined there; see e.g. `std::signbit`, which says `signbit(+nan) = false`, `signbit(-nan) = true`.
- NumPy returns `nan` because someone way back when decided this was needed, and hence there is some custom logic.
- For the array API standard:
  - `sign` was added very early on (https://github.com/data-apis/array-api/pull/36) with some special cases, but without a `nan` special rule.
  - Later that rule was added somewhere, and more recently a separate `signbit` function was added (https://github.com/data-apis/array-api/pull/705).
  - Surely there is some comment somewhere about the difference between `sign` and `signbit`, but it's hard to piece together.
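To make the three behaviours concrete, here is a small NumPy sketch contrasting NaN-propagating `np.sign`, a comparison-based sign, and `np.signbit`. The comparison-based formula `(x > 0) - (x < 0)` is only an illustration that reproduces the 0 results reported for CuPy and PyTorch; it is an assumption, not a claim about how those libraries implement `sign`.

```python
import numpy as np

x = np.asarray([-2.0, -0.0, 0.0, 3.0, np.nan, -np.nan])

# NumPy's sign propagates NaN, as the array API standard specifies.
numpy_sign = np.sign(x)

# Comparison-based sign: both comparisons are False for NaN, so NaN maps to 0.
comparison_sign = (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype)

# signbit only inspects the sign bit: it returns a bool, is True for -0.0,
# and for NaN depends on how that NaN's sign bit happens to be set.
sign_bits = np.signbit(x)

print(numpy_sign)
print(comparison_sign)
print(sign_bits)
```

Negating a NaN flips its sign bit, so `np.nan` and `-np.nan` give different `signbit` results even though neither has a meaningful sign as a number.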
I think a general rule for NaNs is that they propagate unless there is a clear reason why not.
After all, sign is very different from signbit: signbit is defined to return a bool, and it additionally maps to an implementation detail of the IEEE float representation.
Further, returning NaN preserves the full "partial order" (e.g. in C++20) of value <=> 0 (less, equal, greater, and unordered).
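As a sketch of that point, the four outcomes of `value <=> 0` can be read directly off a NaN-propagating sign. The function `compare_to_zero` and the `LABELS` mapping are hypothetical names for illustration; with a sign that maps NaN to 0, the "unordered" case would collapse into "equal".

```python
import math

import numpy as np

# Hypothetical labels for the ordered outcomes of `value <=> 0`.
LABELS = {-1.0: "less", 0.0: "equal", 1.0: "greater"}


def compare_to_zero(x):
    """Map each element of a real float array to less/equal/greater/unordered.

    Relies on np.sign returning NaN for NaN input, so the unordered case
    survives instead of being merged into "equal". A -0.0 result still
    looks up "equal", since -0.0 == 0.0 and both hash the same.
    """
    s = np.sign(np.asarray(x, dtype=float))
    return ["unordered" if math.isnan(v) else LABELS[v] for v in s.tolist()]


print(compare_to_zero([-3.0, 0.0, 2.5, float("nan")]))
```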
So IMO, NumPy does the right thing unless there is a good argument that typical use cases would expect a 0 return for NaN. And it sounds like the report here is a use case where you want the full partial order to be preserved!
TBH, I would lean towards torch fixing this, but if there is a good reason why they don't want to (what is it?), then it has to stay undefined, which seems unfortunate for the actual use case above.
I found this old discussion https://mail.python.org/archives/list/[email protected]/thread/A2JFHOZZOF634CNZ7E27THQEBU4EZFTS/#F3KS7QWVPIXSYB7CSSY37OXYM4JVZTZQ
sign and signbit are completely different things for complex numbers, as I pointed out at https://mail.python.org/archives/list/[email protected]/message/VBYOVSTN2GTBPEJ3OPDS2S5DLPQJFFX3/ It's probably not a coincidence that torch also doesn't define complex sign.
It does seem valuable to figure this out, since there's at least one real-world use-case. Maybe we should move this discussion to the array-api repo.
I have a fix at https://github.com/data-apis/array-api-compat/pull/137, which we can at least use to check the performance implications.
By the way, you can see the other special cases that are not being followed in the xfails files (we do not yet attempt to fix any of them, and most concern operators, which can't be fixed anyway):
- torch: https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/torch-xfails.txt#L89-L179
- cupy: https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/cupy-xfails.txt#L61-L169
- numpy: https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/numpy-xfails.txt#L18-L41
I will say that even though there are quite a few of these, the sign special case seems to stand out. Almost all the rest seem to have to do with handling -0 correctly.
@rgommers The addition of the special case for NaN handling in sign comes from https://github.com/data-apis/array-api/pull/556.
The specification is correct on this, as returning NaN follows naturally from the definition of the signum function,
$$\operatorname{sgn} x = \frac{x}{|x|}, \quad x \neq 0,$$
where NaN/NaN follows the special cases for division, thus ensuring arithmetic consistency.
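A minimal NumPy sketch of that definition, covering both real and complex floating-point input. The name `signum` is hypothetical, and the sketch only illustrates the formula: it handles zero explicitly but not infinities, since inf/inf is NaN in floating point, so a spec-compliant implementation would need further special cases there.

```python
import numpy as np


def signum(x):
    """sgn(x) = x / |x| for x != 0, and 0 for x == 0 (floating-point input only).

    NaN propagates because |NaN| is NaN and NaN / NaN is NaN.
    Infinities are NOT special-cased here: inf / inf gives NaN.
    """
    x = np.asarray(x)
    mag = np.abs(x)
    # Start from zeros so elements with |x| == 0 stay 0; divide only elsewhere.
    # Note NaN != 0 is True, so NaN elements go through the division.
    out = np.zeros_like(x)
    np.divide(x, mag, out=out, where=(mag != 0))
    return out


print(signum(np.asarray([-4.0, 0.0, 2.0, np.nan])))
print(signum(np.asarray([3 + 4j, 0j])))
```

For complex input this gives the unit-modulus direction of `z` (e.g. `3+4j` maps to `0.6+0.8j`), which is the quantity `torch.sgn` is documented to compute for complex tensors.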
This was intentional, not a pre-C99 oversight. We shouldn't expect the signum and signbit functions to return equivalent results, and NumPy is correct here. It would perhaps have been better if NumPy had chosen signum as the name rather than sign, to make this difference in behavior clearer.
In case it's of interest, it turns out that returning NaN may make this faster on the GPU: https://github.com/cupy/cupy/issues/8327
Just thought I'd mention that I ran into this again in a different context.