array-api-compat

`sign` special case implementations

Open mdhaber opened this issue 1 year ago • 13 comments

According to v2022.12 (and v2023.12) of the array API standard, the special cases of sign include:

> For real-valued operands... If x_i is NaN, the result is NaN.
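For reference, the real-valued rules quoted above can be written as a small scalar function. This is a minimal sketch, not code from the standard or from array-api-compat; the name `sign_reference` is hypothetical.

```python
import math


def sign_reference(x: float) -> float:
    """Hypothetical scalar sketch of the real-valued special cases of `sign`."""
    if math.isnan(x):
        return math.nan  # NaN propagates instead of collapsing to 0
    if x > 0:
        return 1.0
    if x < 0:
        return -1.0
    return 0.0  # both +0.0 and -0.0 land here


print([sign_reference(v) for v in (-2.5, -0.0, 0.0, 4.0, math.nan)])  # [-1.0, 0.0, 0.0, 1.0, nan]
```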

However,

```python
from array_api_compat import numpy as np, cupy as cp, torch

np.sign(np.asarray(np.nan))  # nan
cp.sign(cp.asarray(cp.nan))  # array( 0.000e+00)
torch.sign(torch.asarray(torch.nan))  # tensor(0.)
```

There may be other special cases that are not yet implemented. I haven't done a complete review, but I noticed that torch gives an error when the input is complex.

```python
torch.sign(torch.asarray(1+1j))
# RuntimeError: Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.
```

mdhaber, May 09 '24 22:05

I haven't tried to implement any special-case workarounds here yet. I guess if this is causing issues for you, we can add a workaround.

asmeurer, May 10 '24 01:05

We definitely should add complex support to torch.sign. It sounds like that will just be an easy wrapper of torch.sgn.

asmeurer, May 10 '24 01:05

Yeah, I actually did run into this special case while converting code from NumPy to the array API. The code produces a status flag that depends on whether an element is positive, negative, zero, or NaN, so it's no wonder I ran into it. It would be helpful to add the special case, but support for 2023.12 would be a higher priority for me.
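A use case like this can be sketched with NumPy, whose `sign` already propagates NaN. The flag values and the `status_flags` name are hypothetical, chosen only for illustration; with a library whose `sign` returns 0 for NaN, the NaN entries would silently be flagged as zero.

```python
import numpy as np

# Hypothetical status codes; -1, 0 and +1 coincide with the values sign() returns
NEGATIVE, ZERO, POSITIVE, IS_NAN = -1, 0, 1, 2


def status_flags(x: np.ndarray) -> np.ndarray:
    """Map each element to NEGATIVE, ZERO, POSITIVE or IS_NAN using sign()."""
    s = np.sign(x)  # relies on sign(nan) == nan, as in NumPy and the standard
    # NaN survives in s, so it can be told apart from a genuine zero
    return np.where(np.isnan(s), IS_NAN, s).astype(np.int8)


print(status_flags(np.array([-3.0, 0.0, 5.0, np.nan])).tolist())  # [-1, 0, 1, 2]
```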

mdhaber, May 10 '24 02:05

We can add it. The main concern with adding special-case handling is that it means adding a mask to the function, so it could be a minor hit to performance. I would at least make sure there are upstream issues about this filed with the libraries that don't implement it correctly.
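The masking approach can be sketched as a namespace-agnostic wrapper. This is an assumption about what such a workaround could look like, not the code in array-api-compat; `nan_propagating_sign` is a hypothetical name, and `xp` is any namespace providing `sign`, `isnan` and `where` (NumPy by default here).

```python
import numpy as np


def nan_propagating_sign(x, xp=np):
    """Hypothetical compat-style wrapper: call xp.sign, then patch NaN inputs."""
    result = xp.sign(x)
    # The extra isnan pass plus where() is the "mask" that costs performance
    return xp.where(xp.isnan(x), x, result)


print(nan_propagating_sign(np.array([-2.0, 0.0, 3.0, np.nan])))
```

For NumPy the patch is redundant, since `np.sign` already returns NaN; it only matters for a namespace whose `sign` maps NaN to 0.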

asmeurer, May 10 '24 04:05

> The main concern with adding special-case handling is it means adding a mask to the function, so it could be a minor hit to performance.

I don't think that would be a minor hit. For element-wise functions, adding isnan checks and masking may slow things down by 2x or more.
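The size of that overhead is easy to measure locally. A rough micro-benchmark sketch with NumPy and `timeit` follows; `masking_slowdown` is a hypothetical helper, and the ratio it reports depends on the machine, array size and library, so no particular number is implied.

```python
import timeit

import numpy as np


def masking_slowdown(n: int = 1_000_000, number: int = 20, repeat: int = 5) -> float:
    """Ratio of (sign + isnan mask) time to plain sign time; machine-dependent."""
    rng = np.random.default_rng(0)
    x = rng.standard_normal(n)
    x[::100] = np.nan  # sprinkle NaNs so the mask has something to patch

    plain = min(timeit.repeat(lambda: np.sign(x), number=number, repeat=repeat))
    masked = min(timeit.repeat(
        lambda: np.where(np.isnan(x), x, np.sign(x)), number=number, repeat=repeat))
    return masked / plain


print(f"masked/plain time ratio: {masking_slowdown():.2f}")
```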

In general, these special cases have not been well validated yet, so I'd be quite reluctant to assume they are all correctly specified or support them in the compat layer.

In this case, it may be possible to fix in NumPy. There's a bunch of special-casing to make sign(nan) return nan, which looks like old code and disagrees with C99's signbit on purpose:

https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/_core/src/umath/loops.c.src#L1341-L1350

https://github.com/numpy/numpy/blob/2a9b9134270371b43223fc848b753fceab96b4a5/numpy/_core/src/umath/loops.c.src#L2262-L2271

Changing it would be a minor backwards-compat break, but probably an improvement. Much larger changes were made to np.sign recently in gh-25441 (https://github.com/numpy/numpy/pull/25441); however, sign(nan) = nan was not discussed there. It does seem like sign deviates from signbit for no good reason; it was probably an ad-hoc pre-C99 decision.
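NumPy exposes both functions, so the difference between them can be checked directly. In this sketch `sign` propagates NaN regardless of its sign bit, while `signbit` returns a bool that reads the sign bit itself, including for -0.0 and a NaN whose sign bit is set.

```python
import math

import numpy as np

pos_nan = math.copysign(math.nan, 1.0)   # NaN with the sign bit clear
neg_nan = math.copysign(math.nan, -1.0)  # NaN with the sign bit set
x = np.array([-2.0, -0.0, 0.0, 3.0, pos_nan, neg_nan])

s = np.sign(x)     # signum: -1, 0, +1, and NaN for both NaN inputs
b = np.signbit(x)  # bool: True wherever the sign bit is set
print(s)
print(b)
```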

rgommers, May 10 '24 10:05

NumPy is the one that agrees with what the array API says. If we think sign(nan) should not return NaN, then the array API should change.

asmeurer, May 10 '24 15:05

> If we think sign(nan) should not return NaN, then the array API should change.

I'm honestly not sure. These special cases are a pain: it takes a lot of time to figure out what was discussed before and why different libraries ended up with different return values. I suspect the following:

  • PyTorch and CuPy here simply use the C/C++ version of signbit. Special cases are defined there; see, e.g., std::signbit, which says signbit(+nan) = false and signbit(-nan) = true
  • NumPy returns nan because someone way back when decided this was needed, and hence there is some custom logic
  • For the array API standard:
    • sign was added very early on (https://github.com/data-apis/array-api/pull/36) with some special cases but without a nan special rule
    • later that rule was added somewhere, and also more recently a separate signbit function was added (https://github.com/data-apis/array-api/pull/705)
    • surely there is some comment somewhere about the difference between sign and signbit, but it's hard to piece together.
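One plausible way an implementation can end up returning 0 for NaN is the common comparison idiom `(x > 0) - (x < 0)`: every ordered comparison involving NaN is false, so both terms are 0. This is an assumption for illustration, not a description of the PyTorch or CuPy source; `sign_by_comparison` is a hypothetical name, and `np.sign` is shown alongside for contrast.

```python
import math

import numpy as np


def sign_by_comparison(x: float) -> float:
    """C-style idiom (x > 0) - (x < 0); both comparisons are False for NaN."""
    return float((x > 0) - (x < 0))


for v in (-3.0, 0.0, 2.0, math.nan):
    # For NaN the idiom yields 0.0 while NumPy's sign yields nan
    print(v, sign_by_comparison(v), np.sign(v))
```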

rgommers, May 10 '24 18:05

I think one normal rule of NaNs is that they are propagated unless there is a clear reason why not. After all, sign is very different from signbit: the latter is defined to return a bool and additionally maps to an implementation detail of the IEEE float representation.

Further, returning NaN preserves the full "partial order" (e.g. in C++20) of value <=> 0 (less, equal, greater, and unordered).
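The "partial order" argument can be made concrete: a NaN-propagating sign can be decoded back into all four outcomes of `x <=> 0`, whereas a sign that maps NaN to 0 folds "unordered" into "equal". In this sketch, `Ordering`, `compare_to_zero`, `ordering_from_sign`, `sign_nan` and `sign_zero` are all hypothetical names.

```python
import math
from enum import Enum


class Ordering(Enum):
    """The four outcomes of C++20 `x <=> 0` on doubles (std::partial_ordering)."""
    LESS = "less"
    EQUAL = "equal"
    GREATER = "greater"
    UNORDERED = "unordered"


def compare_to_zero(x: float) -> Ordering:
    if x < 0:
        return Ordering.LESS
    if x > 0:
        return Ordering.GREATER
    if x == 0:
        return Ordering.EQUAL
    return Ordering.UNORDERED  # only NaN reaches this point


def ordering_from_sign(s: float) -> Ordering:
    """Decode a sign() result: -1, 0, +1, or NaN meaning 'unordered'."""
    return Ordering.UNORDERED if math.isnan(s) else compare_to_zero(s)


def sign_nan(x: float) -> float:
    # NaN-propagating sign, as NumPy and the standard define it
    return x if math.isnan(x) else float((x > 0) - (x < 0))


def sign_zero(x: float) -> float:
    # Sign that maps NaN to 0, like the CuPy and PyTorch results reported above
    return float((x > 0) - (x < 0))


for v in (-2.0, 0.0, 5.0, math.nan):
    print(v, compare_to_zero(v).value,
          ordering_from_sign(sign_nan(v)).value,
          ordering_from_sign(sign_zero(v)).value)
```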

So IMO, NumPy does the right thing unless there is a good argument why typical use cases would expect a 0 return for NaN, and it sounds like the report here is a use case where you want the full partial order to be preserved!

TBH, I would lean towards torch fixing this, but if there is a good reason why they don't want to (what is it?), then it has to stay undefined, which seems unfortunate for the actual use case above.

seberg, May 10 '24 19:05

I found this old discussion https://mail.python.org/archives/list/[email protected]/thread/A2JFHOZZOF634CNZ7E27THQEBU4EZFTS/#F3KS7QWVPIXSYB7CSSY37OXYM4JVZTZQ

sign and signbit are completely different things for complex numbers, as I pointed out at https://mail.python.org/archives/list/[email protected]/message/VBYOVSTN2GTBPEJ3OPDS2S5DLPQJFFX3/ It's probably not a coincidence that torch also doesn't define complex sign.

It does seem valuable to figure this out, since there's at least one real-world use-case. Maybe we should move this discussion to the array-api repo.

asmeurer, May 10 '24 19:05

I have a fix at https://github.com/data-apis/array-api-compat/pull/137, which we can at least use to check the performance implications.

asmeurer, May 10 '24 20:05

By the way, you can see the other special cases that are not being followed in the xfails files (we do not yet attempt to fix any of them, and most are on operators, which can't be fixed anyway):

torch https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/torch-xfails.txt#L89-L179

cupy https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/cupy-xfails.txt#L61-L169

numpy https://github.com/data-apis/array-api-compat/blob/376038ed9f4337cdec78f21a5ccb3e2b6d948786/numpy-xfails.txt#L18-L41

I will say that even though there are quite a few of these, the sign special case seems to stand out. Almost all the rest seem to have to do with handling -0 correctly.
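Checking -0 handling needs more than `==`, because `-0.0 == 0.0` is true; the sign bit has to be inspected explicitly. A minimal sketch of such a check follows; `is_negative_zero` and `negative_zero_mask` are hypothetical helpers, not functions from the test suite.

```python
import math

import numpy as np


def is_negative_zero(x: float) -> bool:
    """True only for -0.0; plain equality cannot distinguish it from +0.0."""
    return x == 0 and math.copysign(1.0, x) < 0


def negative_zero_mask(x: np.ndarray) -> np.ndarray:
    """Vectorized version: zero-valued elements whose sign bit is set."""
    return (x == 0) & np.signbit(x)


print(-0.0 == 0.0, is_negative_zero(-0.0), is_negative_zero(0.0))  # True True False
print(negative_zero_mask(np.array([-0.0, 0.0, -1.0, np.nan])).tolist())  # [True, False, False, False]
```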

asmeurer, May 10 '24 20:05

@rgommers The addition of the special case for NaN handling in sign comes from https://github.com/data-apis/array-api/pull/556.

The specification is correct on this, as returning NaN follows naturally from the definition of the signum function where

$$\operatorname{sgn} x = \frac{x}{|x|}$$

and where NaN/NaN follows the special cases for division (the result is NaN), thus ensuring arithmetic consistency.
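The division-based definition can be sketched in plain Python to show that NaN propagation falls out of `x / abs(x)` without a separate NaN check. This `signum` is a hypothetical illustration: it maps zero to itself because 0/0 is undefined, and it special-cases real infinities, since inf/inf would give NaN while the real-valued rules give ±1. Complex infinities, which the standard handles via the complex division rules, are not covered here.

```python
import math


def signum(x: complex) -> complex:
    """Hypothetical sketch of sgn x = x / |x| for real or complex x."""
    if x == 0:
        return x  # +0, -0 or a complex zero stays zero (0/0 is undefined)
    if isinstance(x, float) and math.isinf(x):
        return math.copysign(1.0, x)  # inf/inf would be NaN; real rules say +-1
    return x / abs(x)  # NaN / NaN is NaN, so NaN propagates with no extra check


print(signum(-4.0), signum(3 + 4j))  # -1.0 (0.6+0.8j)
```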

This was intentional, not a pre-C99 oversight. We shouldn't expect the signum and signbit functions to return equivalent results, and NumPy is correct here. It would perhaps have been better if NumPy had chosen signum rather than sign as the name, to make this difference in behavior clearer.

kgryte, May 11 '24 17:05

In case it's interesting, it turns out that returning NaN may make this faster on the GPU: https://github.com/cupy/cupy/issues/8327

seberg, May 18 '24 08:05

Just thought I'd mention that I ran into this again in a different context.

mdhaber, Aug 25 '24 19:08