array-api
Add complex number support to `sqrt`
This PR

- adds complex number support to `sqrt` by documenting special cases. By convention, the square root has a single branch cut, which is defined as the real interval `[-infinity, 0)`.
- updates the input and output array data types to be any floating-point data type, not just real-valued floating-point data types.
- derives special cases from C99 and tests them against NumPy (script below).
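For reference, the convention above corresponds to the textbook principal square root (stated here as background, not quoted from the PR). Writing $z = a + bi$ with $|z| = \sqrt{a^2 + b^2}$, and using $\operatorname{copysign}(m, b)$ to mean "magnitude $m$ with the sign bit of $b$" (so signed zeros are respected):

$$
\sqrt{z} = \exp\!\left(\tfrac{1}{2}\operatorname{Log} z\right) = \sqrt{\frac{|z| + a}{2}} \;+\; i\,\operatorname{copysign}\!\left(\sqrt{\frac{|z| - a}{2}},\; b\right)
$$

Crossing the negative real axis flips the sign of the imaginary part of the result, for example

$$
\lim_{\epsilon \to 0^{+}} \sqrt{-4 + i\epsilon} = 2i, \qquad \lim_{\epsilon \to 0^{+}} \sqrt{-4 - i\epsilon} = -2i,
$$

which is why `[-infinity, 0)` is the branch cut. On the cut itself, the sign of a zero imaginary component selects the side, consistent with the C99 identity $\operatorname{csqrt}(\overline{z}) = \overline{\operatorname{csqrt}(z)}$.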
```python
import math

import numpy as np


def is_equal_float(x, y):
    """Test whether two floating-point numbers are equal with special consideration for zeros and NaNs.

    Parameters
    ----------
    x : float
        First input number.
    y : float
        Second input number.

    Returns
    -------
    bool
        Boolean indicating whether two floating-point numbers are equal.

    Examples
    --------
    >>> is_equal_float(0.0, -0.0)
    False
    >>> is_equal_float(-0.0, -0.0)
    True
    """
    # Handle +-0:
    if x == 0.0 and y == 0.0:
        return math.copysign(1.0, x) == math.copysign(1.0, y)
    # Handle NaNs:
    if x != x:
        return y != y
    # Everything else, including infinities:
    return x == y


def is_equal(x, y):
    """Test whether two complex numbers are equal with special consideration for zeros and NaNs.

    Parameters
    ----------
    x : complex
        First input number.
    y : complex
        Second input number.

    Returns
    -------
    bool
        Boolean indicating whether two complex numbers are equal.

    Examples
    --------
    >>> import numpy as np
    >>> is_equal(complex(np.nan, np.nan), complex(np.nan, np.nan))
    True
    """
    return is_equal_float(x.real, y.real) and is_equal_float(x.imag, y.imag)


# Strided array consisting of input values and expected values:
values = [
    complex(0.0, 0.0),         # 0
    complex(0.0, 0.0),         # 0
    complex(-0.0, 0.0),        # 1
    complex(0.0, 0.0),         # 1
    complex(1.0, np.inf),      # 2
    complex(np.inf, np.inf),   # 2
    complex(np.nan, np.inf),   # 3
    complex(np.inf, np.inf),   # 3
    complex(0.0, np.inf),      # 4
    complex(np.inf, np.inf),   # 4
    complex(1.0, np.nan),      # 5
    complex(np.nan, np.nan),   # 5
    complex(-np.inf, 1.0),     # 6
    complex(0.0, np.inf),      # 6
    complex(np.inf, 1.0),      # 7
    complex(np.inf, 0.0),      # 7
    complex(-np.inf, np.nan),  # 8
    complex(np.nan, np.inf),   # 8, imaginary component sign unspecified
    complex(np.inf, np.nan),   # 9
    complex(np.inf, np.nan),   # 9
    complex(np.nan, 1.0),      # 10
    complex(np.nan, np.nan),   # 10
    complex(np.nan, np.nan),   # 11
    complex(np.nan, np.nan)    # 11
]

for i in range(len(values) // 2):
    j = i * 2
    v = values[j]
    e = values[j + 1]
    actual = np.sqrt(v)
    print('Index: {index}'.format(index=str(i)))
    print('Value: {value}'.format(value=str(v)))
    print('Actual: {actual}'.format(actual=str(actual)))
    print('Expected: {expected}'.format(expected=str(e)))
    print('Equal: {is_equal}'.format(is_equal=str(is_equal(actual, e))))
    print('\n')
```
```
Index: 0
Value: 0j
Actual: 0j
Expected: 0j
Equal: True

Index: 1
Value: (-0+0j)
Actual: 0j
Expected: 0j
Equal: True

Index: 2
Value: (1+infj)
Actual: (inf+infj)
Expected: (inf+infj)
Equal: True

Index: 3
Value: (nan+infj)
Actual: (inf+infj)
Expected: (inf+infj)
Equal: True

Index: 4
Value: infj
Actual: (inf+infj)
Expected: (inf+infj)
Equal: True

Index: 5
Value: (1+nanj)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True

Index: 6
Value: (-inf+1j)
Actual: infj
Expected: infj
Equal: True

Index: 7
Value: (inf+1j)
Actual: (inf+0j)
Expected: (inf+0j)
Equal: True

Index: 8
Value: (-inf+nanj)
Actual: (nan+infj)
Expected: (nan+infj)
Equal: True

Index: 9
Value: (inf+nanj)
Actual: (inf+nanj)
Expected: (inf+nanj)
Equal: True

Index: 10
Value: (nan+1j)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True

Index: 11
Value: (nan+nanj)
Actual: (nan+nanj)
Expected: (nan+nanj)
Equal: True
```
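The branch-cut convention can also be checked directly. The following is a minimal sketch that uses Python's standard-library `cmath.sqrt` as a stand-in for an array library (an assumption: it is not the array API `sqrt`, but on platforms with signed zeros `cmath` uses the sign of a zero imaginary part to select the side of a branch cut). It evaluates the square root just above, on, and just below the cut at `-4`.

```python
import cmath
import math

# Sample points straddling the branch cut [-infinity, 0) at x = -4.
# cmath.sqrt is used here only as a stand-in for an array library's sqrt.
points = {
    "above": complex(-4.0, 1e-300),     # tiny positive imaginary part
    "on_cut_pos": complex(-4.0, 0.0),   # on the cut, imaginary part +0.0
    "on_cut_neg": complex(-4.0, -0.0),  # on the cut, imaginary part -0.0
    "below": complex(-4.0, -1e-300),    # tiny negative imaginary part
}

# Principal square root of each point; the imaginary part flips sign across the cut.
results = {name: cmath.sqrt(z) for name, z in points.items()}

for name, r in results.items():
    sign = "+" if math.copysign(1.0, r.imag) > 0 else "-"
    print(f"{name:>10}: sqrt({points[name]!r}) = {r!r} (imaginary sign {sign})")
```

On the cut itself, `complex(-4.0, 0.0)` maps to `2j` and `complex(-4.0, -0.0)` maps to `-2j`, matching the limits approached from above and from below, respectively.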
General comments, since I see you've started working on ufuncs that have a branch cut 🙂
- The explanation of a branch cut should be moved to the top of the section so that all functions can refer or hyperlink to it.
- Are we now going to require all libraries to pick the same branch cut? I suspect this is eventually where we need to land (a set of de facto choices already implemented by everyone in real life), but I'd like to raise this for discussion.
@leofang Re: definition of branch cut. I can move it to the list of specification definitions.
Re: same branch cuts. Currently, I am approaching this on a case-by-case basis. For `sqrt`, the chosen branch cut did not seem, at least to me, controversial. For other transcendental functions, the choices are straightforward, as described by Kahan.
However, one idea I entertained is, for functions having branch cuts (apart from functions with unanimous conventions, such as `sqrt`), to specify the branch cuts under a "provisional" status. We'd update the test suite accordingly; however, instead of failing, we could warn.
If we find that, for all practical purposes, array libraries all adhere to the same branch cuts, we could move from provisional to non-provisional in a later revision of the standard.
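The "warn instead of fail" idea could look roughly like the following minimal sketch. Everything here is hypothetical: `check_special_case` and `ProvisionalBehaviorWarning` are illustrative names, not part of the array API test suite, and `cmath.sqrt` again stands in for an array library's `sqrt`. A mismatch raises for non-provisional cases but only emits a warning when the case is marked provisional.

```python
import cmath
import math
import warnings


class ProvisionalBehaviorWarning(UserWarning):
    """Hypothetical warning category for checks against provisional parts of the standard."""


def same_float(x, y):
    """Compare floats, treating NaNs as equal and distinguishing +0.0 from -0.0."""
    if math.isnan(x) or math.isnan(y):
        return math.isnan(x) and math.isnan(y)
    if x == 0.0 and y == 0.0:
        return math.copysign(1.0, x) == math.copysign(1.0, y)
    return x == y


def check_special_case(func, z, expected, provisional=False):
    """Hypothetical helper: return True on a match; on a mismatch, warn if provisional, else raise."""
    actual = func(z)
    if same_float(actual.real, expected.real) and same_float(actual.imag, expected.imag):
        return True
    msg = f"{func.__name__}({z!r}) returned {actual!r}, expected {expected!r}"
    if provisional:
        # Provisional branch-cut behavior: report the discrepancy without failing.
        warnings.warn(msg, ProvisionalBehaviorWarning, stacklevel=2)
        return False
    raise AssertionError(msg)


# The upper side of the cut matches, so no warning is emitted.
check_special_case(cmath.sqrt, complex(-4.0, 0.0), complex(0.0, 2.0), provisional=True)
```

Promoting a branch cut from provisional to non-provisional would then amount to dropping `provisional=True` at the call site.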
This still needs the refactoring and the marking of the branch cut as provisional. Otherwise, I don't see any concerns with getting it merged.
Updated to include a warning documenting that the choice of branch cut is under provisional status and moved the definition of a branch cut to the list of specification definitions.
This PR should be ready for final review.
Updated the PR by adding a discussion concerning branch cuts to the design topics and linking back to that document from `sqrt`.
This PR should be ready for another review.
Moving special cases elsewhere can be done, but it should happen in a follow-up PR and apply to all relevant functions, as this is generally applicable throughout the specification.
Created an issue for reorganizing documentation content (see gh-518). The tentative plan is to reorganize content in one fell swoop, rather than begin reorganizing content in a piecemeal fashion beginning with this PR.
As this PR has been discussed in consortium meetings and no objections have been raised concerning the branch cut here or elsewhere, I will merge. Because this PR contains the design topic on branch cuts, merging it will unblock other PRs which need to refer to that document. Any further changes can be addressed in follow-up PRs.