Array API: Typing

Open ntessore opened this issue 1 year ago • 10 comments

We want to use the heroic efforts by @paddyroddy to have a fully typed GLASS to make the transition to the Array API easier. For example, if we had a full Array API typing implementation, we would trivially find issues such as atan vs arctan with a run of mypy.
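A minimal sketch of the kind of check meant here. The `Namespace` Protocol and the `MathNamespace` class are hypothetical, and plain floats stand in for arrays so that only the standard library is needed:

```python
import math
from typing import Protocol


class Namespace(Protocol):
    """Hypothetical, heavily trimmed stand-in for an Array API namespace."""

    # The standard spells the inverse tangent `atan`; `arctan` is the NumPy name.
    def atan(self, x: float, /) -> float: ...


class MathNamespace:
    """Toy namespace backed by the stdlib, only here to make the sketch runnable."""

    def atan(self, x: float, /) -> float:
        return math.atan(x)


def angle(xp: Namespace, x: float) -> float:
    # Writing xp.arctan(x) here should make mypy report an attr-defined
    # error, because the Protocol declares no `arctan`.
    return xp.atan(x)


result = angle(MathNamespace(), 1.0)
```

With a complete set of such signatures, a single mypy run would flag every NumPy-only spelling in the code base.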

@nstarman maintains a great resource in https://github.com/nstarman/array_api. Maybe we can look into how we can use this to everyone's benefit?

ntessore avatar Nov 13 '24 21:11 ntessore

I spent some time on this, and we really can't do anything until the array API standard itself provides a typing Protocol (or a superclass for typing purposes that every array API compliant library subclasses, which seems quite excessive).
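To illustrate the difference between the two options, here is a small sketch of structural typing with a Protocol. `HasShape` and `ThirdPartyArray` are made-up names, not part of any library. The point is that the array class satisfies the Protocol without inheriting from it, which a shared superclass would not allow:

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class HasShape(Protocol):
    """Hypothetical Protocol: anything exposing a `shape` tuple."""

    @property
    def shape(self) -> tuple[int, ...]: ...


class ThirdPartyArray:
    """Stand-in for a provider library's array; it does not subclass HasShape."""

    def __init__(self, n: int) -> None:
        self._n = n

    @property
    def shape(self) -> tuple[int, ...]:
        return (self._n,)


def ndim(x: HasShape) -> int:
    # Type checkers accept ThirdPartyArray here purely because of its structure.
    return len(x.shape)


n = ndim(ThirdPartyArray(3))
```

Note that `isinstance` checks against a `runtime_checkable` Protocol only look for the presence of members, not their signatures, so the real verification happens in the type checker.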

@nstarman's (👋🏽 you might remember me from removestar) https://github.com/data-apis/array-api/pull/589 is very promising and hopefully will be merged soon 🤞🏽

The PR is almost ready (and approved), and it will be discussed at the array API community meeting tomorrow.

Once that is merged and released, we can expect the new Protocol to solve our typing issues.

Marking this issue as blocked till the Array Protocol is a part of the Array API standard.

Saransh-cpp avatar Nov 27 '24 15:11 Saransh-cpp

https://github.com/data-apis/array-api-typing 👀

Saransh-cpp avatar Dec 02 '24 09:12 Saransh-cpp

Hi! https://github.com/data-apis/array-api/pull/589 is probably never going to be merged, but it has evolved into https://github.com/data-apis/array-api-typing, where we're slowly porting over / fixing parts of https://github.com/nstarman/array_api.

nstarman avatar Dec 05 '24 06:12 nstarman

We should create a custom typing Protocol for now.

Starting point (Array and ArrayNamespace Protocols) - https://github.com/glass-dev/glass/blob/f993b61be34cb12ed31f041dc981fd6a6d854302/glass/grf/_core.py#L17-L47

Saransh-cpp avatar Apr 04 '25 10:04 Saransh-cpp

Turns out, this is not as simple as it looked. I implemented custom Protocols, and even revived https://github.com/nstarman/array_api to work with newer versions of the array API standard, but it does not seem to solve our problem.

Taking an example of one of the essential methods, __array_namespace__:

NumPy defines the method as -

def __array_namespace__(self, /, *, api_version: _ArrayAPIVersion | None = None) -> ModuleType: ...

where -

from typing import Literal as L

_ArrayAPIVersion: TypeAlias = L["2021.12", "2022.12", "2023.12"]

which is consistent with array API standard 2024.12 (however, NumPy does not entirely support 2024.12 yet).

JAX defines the same method as -

def __array_namespace__(self, *, api_version: None | str = ...) -> ModuleType: ...

which is consistent with array API standard 2023.12 (no 2024.12 support yet).

Therefore, these methods (along with a few others) differ in their type signatures, making it impossible to use Protocols without some standardization effort.

We should just use -

NDArray[...] | ArrayLike

till the time there is an official typing implementation. Marking this as blocked again 😞

Saransh-cpp avatar Apr 07 '25 14:04 Saransh-cpp

We should just use -

NDArray[...] | ArrayLike

till the time there is an official typing implementation. Marking this as blocked again 😞

I'm inclined to agree with this. I feel that we are working on Array API whilst it is under active development. If we simply wait for a bit this problem will be solved for us.

paddyroddy avatar Apr 07 '25 14:04 paddyroddy

Looking at this again, the error above looks fixable by overloading, but here is an MWE for problems with NumPy's typing (SuperArray.__add__'s typing here strictly follows the array API standard) -

import numpy as np
import typing

Array = typing.TypeVar("Array", bound="SuperArray")

class SuperArray(typing.Protocol):
    def __add__(self: Array, other: int | float | complex | Array, /) -> Array: ...

def myfunc(array: SuperArray) -> SuperArray:
    return array

myfunc(np.eye(3))
(.env) saransh@Saranshs-MacBook-Pro glass % python3 -m mypy test_mypy.py --show-error-end
test_mypy.py:15:8:15:16: error: Argument 1 to "myfunc" has incompatible type "ndarray[tuple[int, ...], dtype[float64]]"; expected "SuperArray"  [arg-type]
test_mypy.py:15:8:15:16: note: Following member(s) of "ndarray[tuple[int, ...], dtype[float64]]" have conflicts:
test_mypy.py:15:8:15:16: note:     Expected:
test_mypy.py:15:8:15:16: note:         def __add__(self, int | float | complex | ndarray[tuple[int, ...], dtype[float64]], /) -> ndarray[tuple[int, ...], dtype[float64]]
test_mypy.py:15:8:15:16: note:     Got:
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, int | numpy.bool[builtins.bool], /) -> ndarray[tuple[int, ...], dtype[float64]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[numpy.bool[builtins.bool]]] | _NestedSequence[_SupportsArray[dtype[numpy.bool[builtins.bool]]]] | builtins.bool | _NestedSequence[builtins.bool], /) -> ndarray[tuple[int, ...], dtype[float64]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[floating[_64Bit] | floating[_32Bit] | floating[_16Bit] | integer[Any] | numpy.bool[builtins.bool]]] | _NestedSequence[_SupportsArray[dtype[floating[_64Bit] | floating[_32Bit] | floating[_16Bit] | integer[Any] | numpy.bool[builtins.bool]]]] | float | int | _NestedSequence[float | int], /) -> ndarray[tuple[int, ...], dtype[float64]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[floating[_64Bit]]] | _NestedSequence[_SupportsArray[dtype[floating[_64Bit]]]], /) -> ndarray[tuple[int, ...], dtype[float64]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[complexfloating[_64Bit, _64Bit]]] | _NestedSequence[_SupportsArray[dtype[complexfloating[_64Bit, _64Bit]]]], /) -> ndarray[tuple[int, ...], dtype[complex128]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[numpy.bool[builtins.bool]] | dtype[integer[Any]] | dtype[floating[Any]]] | _NestedSequence[_SupportsArray[dtype[numpy.bool[builtins.bool]] | dtype[integer[Any]] | dtype[floating[Any]]]] | builtins.bool | int | float | _NestedSequence[builtins.bool | int | float], /) -> ndarray[tuple[int, ...], dtype[floating[Any]]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[numpy.bool[builtins.bool]] | dtype[integer[Any]] | dtype[floating[Any]] | dtype[complexfloating[Any, Any]]] | _NestedSequence[_SupportsArray[dtype[numpy.bool[builtins.bool]] | dtype[integer[Any]] | dtype[floating[Any]] | dtype[complexfloating[Any, Any]]]] | builtins.bool | int | float | complex | _NestedSequence[builtins.bool | int | float | complex], /) -> ndarray[tuple[int, ...], dtype[complexfloating[Any, Any]]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[numpy.bool[builtins.bool]] | dtype[number[Any, int | float | complex]]] | _NestedSequence[_SupportsArray[dtype[numpy.bool[builtins.bool]] | dtype[number[Any, int | float | complex]]]] | builtins.bool | int | float | complex | _NestedSequence[builtins.bool | int | float | complex], /) -> ndarray[tuple[int, ...], dtype[number[Any, int | float | complex]]]
test_mypy.py:15:8:15:16: note:         @overload
test_mypy.py:15:8:15:16: note:         def __add__(self, _SupportsArray[dtype[object_]] | _NestedSequence[_SupportsArray[dtype[object_]]], /) -> Any
Found 1 error in 1 file (checked 1 source file)

Surprisingly, JAX works with the Protocols I defined (that is, with the exact type signatures of the array API standard). Maybe I should raise this in NumPy.

Saransh-cpp avatar Apr 07 '25 15:04 Saransh-cpp

Hm, could you summarise the issue here? In any case, I'd say we ignore the versioned __array_namespace__() since that is a minefield still. In which case, the snippet above seems to work fine with some small tweaks? (At least in mypy; pyright complains about the covariance of numpy's DType, as well as ModuleType for ArrayNamespace, which I think are both valid complaints.)

from typing import Protocol, Self, TypeVar


ArrayT = TypeVar("ArrayT", bound="Array")


class ArrayNamespace(Protocol[ArrayT]):
    """Protocol for array namespaces."""

    pi: float

    def arange(self, n: int) -> ArrayT: ...

    def sqrt(self, x: ArrayT) -> ArrayT: ...
    def exp(self, x: ArrayT) -> ArrayT: ...
    def expm1(self, x: ArrayT) -> ArrayT: ...
    def log1p(self, x: ArrayT) -> ArrayT: ...


class Array(Protocol):
    """Protocol for arrays."""

    @property
    def shape(self) -> tuple[int, ...]: ...

    def __array_namespace__(self) -> ArrayNamespace[Self]: ...

    def __add__(self, other: Self | float) -> Self: ...
    def __sub__(self, other: Self | float) -> Self: ...
    def __mul__(self, other: Self | float) -> Self: ...
    def __truediv__(self, other: Self | float) -> Self: ...
    def __pow__(self, other: Self | float) -> Self: ...

    def __radd__(self, other: Self | float) -> Self: ...
    def __rsub__(self, other: Self | float) -> Self: ...
    def __rmul__(self, other: Self | float) -> Self: ...
    def __rtruediv__(self, other: Self | float) -> Self: ...
    def __rpow__(self, other: Self | float) -> Self: ...


## this checks with mypy --strict

import numpy as np
import jax.numpy as jnp

def exp(x: Array) -> Array:
    xp = x.__array_namespace__()
    return xp.exp(x)

exp(np.zeros(10))
exp(jnp.zeros(10))

ntessore avatar Apr 14 '25 15:04 ntessore

@ntessore sorry for taking too long. I discussed this issue with @paddyroddy -

We will have to add a signature for every Array API function used within glass to the Protocol. For instance, I get the following errors with the Protocols right now -

mypy.....................................................................Failed
- hook id: mypy
- exit code: 1

glass/algorithm.py:161: error: "ArrayNamespace[Array]" has no attribute "linalg"  [attr-defined]
glass/algorithm.py:165: error: "ArrayNamespace[Array]" has no attribute "finfo"  [attr-defined]
glass/algorithm.py:168: error: "ArrayNamespace[Array]" has no attribute "clip"  [attr-defined]
glass/algorithm.py:168: error: "ArrayNamespace[Array]" has no attribute "max"  [attr-defined]
glass/algorithm.py:173: error: "ArrayNamespace[Array]" has no attribute "matmul"  [attr-defined]
glass/algorithm.py:173: error: "ArrayNamespace[Array]" has no attribute "matrix_transpose"  [attr-defined]
glass/algorithm.py:207: error: "ArrayNamespace[Array]" has no attribute "linalg"  [attr-defined]
glass/algorithm.py:217: error: "ArrayNamespace[Array]" has no attribute "finfo"  [attr-defined]
glass/algorithm.py:217: error: "Array" has no attribute "dtype"  [attr-defined]
glass/algorithm.py:220: error: "ArrayNamespace[Array]" has no attribute "reshape"  [attr-defined]
glass/algorithm.py:223: error: "ArrayNamespace[Array]" has no attribute "zeros_like"  [attr-defined]
glass/algorithm.py:226: error: "ArrayNamespace[Array]" has no attribute "eye"  [attr-defined]
glass/algorithm.py:243: error: "ArrayNamespace[Array]" has no attribute "all"  [attr-defined]
glass/algorithm.py:247: error: "ArrayNamespace[Array]" has no attribute "reshape"  [attr-defined]
glass/algorithm.py:281: error: "ArrayNamespace[Array]" has no attribute "linalg"  [attr-defined]
glass/algorithm.py:284: error: "ArrayNamespace[Array]" has no attribute "any"  [attr-defined]
glass/algorithm.py:289: error: Incompatible types in assignment (expression has type "Array", variable has type "ndarray[tuple[int, ...], dtype[float64]]")  [assignment]
glass/algorithm.py:293: error: "ArrayNamespace[Array]" has no attribute "where"  [attr-defined]
glass/algorithm.py:293: error: "ArrayNamespace[Array]" has no attribute "asarray"  [attr-defined]
glass/fields.py:1075: error: Incompatible import of "cov_method" (imported name has type "Callable[[Array, float | None], Array]", local name has type "Callable[..., ndarray[tuple[int, ...], dtype[Any]]]")  [assignment]
glass/fields.py:1077: error: Incompatible import of "cov_method" (imported name has type "Callable[[Array, float | None, int], Array]", local name has type "Callable[..., ndarray[tuple[int, ...], dtype[Any]]]")  [assignment]
Found 21 errors in 2 files (checked 16 source files)

We can definitely do this, but I feel this approach is very unmaintainable. Moreover, the discussion in https://github.com/numpy/numpy/issues/28665 points out that NumPy's typing system is not built for this (no support for dtype, overloaded methods working only if you're matching on the first overload, ...), and we should wait for array-api-typing for more robust and maintainable type hints.
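To give a sense of the maintenance cost, here is a rough sketch of how the namespace Protocol would have to grow by hand. The signatures are simplified and only cover three of the missing functions; `ListNamespace` and `clip_unit` are hypothetical helpers over plain lists, used only so the sketch runs without NumPy or JAX:

```python
from __future__ import annotations

from typing import Protocol, TypeVar

ArrayT = TypeVar("ArrayT")


class ArrayNamespace(Protocol[ArrayT]):
    """Hand-maintained namespace Protocol (simplified signatures)."""

    def clip(self, x: ArrayT, /, min: float | None = None, max: float | None = None) -> ArrayT: ...
    def where(self, condition: ArrayT, x1: ArrayT, x2: ArrayT, /) -> ArrayT: ...
    def zeros_like(self, x: ArrayT, /) -> ArrayT: ...
    # ... plus linalg, finfo, max, matmul, matrix_transpose, reshape, eye,
    # all, any, asarray, and whatever else the code base starts using.


class ListNamespace:
    """Toy namespace over flat lists of floats; not a real array library."""

    def clip(self, x: list[float], /, min: float | None = None, max: float | None = None) -> list[float]:
        lo = float("-inf") if min is None else min
        hi = float("inf") if max is None else max
        return [lo if v < lo else hi if v > hi else v for v in x]

    def where(self, condition: list[float], x1: list[float], x2: list[float], /) -> list[float]:
        return [a if c else b for c, a, b in zip(condition, x1, x2)]

    def zeros_like(self, x: list[float], /) -> list[float]:
        return [0.0 for _ in x]


def clip_unit(xp: ArrayNamespace[ArrayT], x: ArrayT) -> ArrayT:
    # Only type-checks because `clip` was added to the Protocol above.
    return xp.clip(x, min=0.0, max=1.0)


clipped = clip_unit(ListNamespace(), [-1.0, 0.5, 2.0])
```

Every new function used in glass would need another line in the Protocol, kept in sync with the standard manually.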

Saransh-cpp avatar Apr 25 '25 10:04 Saransh-cpp

Another rabbit hole that I went down today (unintentionally): Do we care about annotating the dtype of arrays? I was looking into creating a unified type for the arrays (as discussed in the meeting) and I came up with:

UnifiedArray: TypeAlias = NDArray[T] | Array

x: UnifiedArray[float] = <array>

but jax (jaxtyping) represents an array with a dtype as (where "..." is the unknown shape in this case):

Float[Array, "..."]

So, should our type just be:

UnifiedArray: TypeAlias = NDArray | Array

or should we not have a unified type altogether and just annotate functions with (I prefer this):

def something(arg: NDArray[float] | Float[Array, "..."]) -> SomeType: ...
def something(arg: NDArray[complex] | Complex[Array, "..."]) -> SomeType: ...
...

Saransh-cpp avatar Apr 30 '25 15:04 Saransh-cpp

https://github.com/data-apis/array-api-typing

lucascolley avatar Jul 09 '25 21:07 lucascolley

Contributions welcome!

nstarman avatar Jul 09 '25 21:07 nstarman

I'm going to close this. We've settled on a consensus in glass._types. Once array-api-typing is ready we'll move to that.

paddyroddy avatar Nov 04 '25 16:11 paddyroddy