array-api

How to infer appropriate `dtype` from `uint` to `int` and `float` to `complex`?

Status: Open. 34j opened this issue 1 year ago (3 comments).

I would like to compute $f(x) := xi$ and $g(y) := y - 1$ using the array API, where $i$ is the imaginary unit, $x$ has a floating-point dtype, and $y$ has an unsigned integer dtype. However, I am not sure of the best way to implement this. Following the type promotion rules, I would write:

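```python
import array_api_strict as xp  # or any other array API namespace

def f(x):
    # float32 -> complex64, otherwise complex128
    return x * xp.asarray(1j, dtype=xp.complex64 if x.dtype == xp.float32 else xp.complex128)

def g(y):
    # uint8 -> int16, uint16 -> int32, otherwise int64
    return y - xp.asarray(1, dtype=xp.int16 if y.dtype == xp.uint8 else xp.int32 if y.dtype == xp.uint16 else xp.int64)
```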

This seems overly verbose. What is the proper way to do this?

34j, Nov 25 '24 06:11

see gh-841

lucascolley, Nov 25 '24 23:11

That issue covers the complex case. Once it is fixed, `x * 1j` should work as you would expect when `x` has a floating-point dtype.
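Until gh-841 is fixed, the branching in `f` can be avoided by letting `result_type` choose the complex dtype, since the standard's mixed real/complex promotion rules take `float32` with `complex64` to `complex64`, and `float64` with `complex64` to `complex128`. This is a minimal sketch, not an official recipe. It assumes NumPy >= 2.0 as the array namespace, and the helper name `times_i` is hypothetical.

```python
import numpy as xp  # assumption: NumPy >= 2.0; the helper should also work with array_api_strict


def times_i(x):
    """Return x * 1j without upcasting float32 inputs to complex128.

    Hypothetical helper; `x` must have a real floating-point dtype.
    """
    # Mixed real/complex promotion picks the matching complex dtype:
    #   float32 + complex64 -> complex64, float64 + complex64 -> complex128
    cdtype = xp.result_type(x.dtype, xp.complex64)
    return x * xp.asarray(1j, dtype=cdtype)


x32 = xp.asarray([1.0, 2.0], dtype=xp.float32)
x64 = xp.asarray([1.0, 2.0], dtype=xp.float64)
print(times_i(x32).dtype)  # complex64
print(times_i(x64).dtype)  # complex128
```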

For the second case, `y - 1`, the standard currently behaves like this:

```python
>>> import array_api_strict as xp
>>> xp.asarray([0], dtype=xp.uint8) - 1
Array([255], dtype=array_api_strict.uint8)
```

The difference is that your example would make the resulting dtype `int16`. The type promotion rules for scalars say that a Python scalar is first cast to the dtype of the array, so `<uint8 array> - <Python int>` always produces a `uint8` array. In this particular case the result is not actually defined (see the note at https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.neg.html#neg, together with the note saying `x - y == x + (-y)` at https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.sub.html#sub). Small nit: `__sub__` should probably state this more directly, since the result is only undefined when `x - y` is negative.

The rule that scalars always cast to the dtype of the array is not something that should change, so you would need some other way to spell `y - 1` that upcasts to a signed dtype.

I think the astype suggestion in https://github.com/data-apis/array-api/issues/841#issuecomment-2392032433 would be the cleanest way to do this. If it were implemented, you could write

```python
xp.astype(x, 'signed') - 1
```

where `xp.astype(<uint8>, 'signed')` would return an `int16` array (and would error for `uint64`, but that's always a tricky dtype to deal with in the context of type promotion).

The astype improvements idea should be split out into its own issue. I doubt it would be implemented for the 2024 standard release, since it hasn't even been fleshed out yet (though it's not impossible). The complex-scalar-to-float-array issue will definitely be fixed for 2024.

asmeurer, Nov 26 '24 23:11

> I think the astype suggestion in https://github.com/data-apis/array-api/issues/841#issuecomment-2392032433 would be the cleanest way to do this. If it were implemented

> I doubt it would be implemented for the 2024 standard release, since it hasn't even been fleshed out yet (though it's not impossible)

gh-848 😉 (are there complexities which I haven't thought about? A review would be appreciated!)

lucascolley, Nov 26 '24 23:11