xarray
mypy does not understand output of binary operations
What happened?
When doing binary operations between numpy arrays and xarray Variables, mypy does not understand that the output is always an xarray Variable, regardless of operand order. See the example below.
What did you expect to happen?
mypy to pass for the example code.
Minimal Complete Verifiable Example
```python
import numpy as np
import xarray as xr

x = np.array([1, 2, 4])
v = xr.Variable(["x"], x)

# numpy first:
xv = x * v
xv.values  # error: "ndarray[Any, dtype[bool_]]" has no attribute "values"  [attr-defined]
if isinstance(xv, xr.Variable):
    xv.values

# variable first:
vx = v * x
vx.values
if isinstance(vx, xr.Variable):
    vx.values
```
MVCE confirmation
- [X] Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
- [X] Complete example — the example is self-contained, including all data and the text of any traceback.
- [X] Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
- [X] New issue — a search of GitHub Issues suggests this is not a duplicate.
Relevant log output
No response
Anything else we need to know?
Seen in #7741
Environment
xr.show_versions()

```
INSTALLED VERSIONS
commit: None
python: 3.9.16 (main, Mar 8 2023, 10:39:24) [MSC v.1916 64 bit (AMD64)]
python-bits: 64
OS: Windows
OS-release: 10
machine: AMD64
processor: Intel64 Family 6 Model 58 Stepping 9, GenuineIntel
byteorder: little
LC_ALL: None
LANG: en
libhdf5: 1.10.6
libnetcdf: None

xarray: 2023.4.2
pandas: 2.0.0
numpy: 1.23.5
scipy: 1.10.1
netCDF4: None
pydap: None
h5netcdf: None
h5py: 2.10.0
Nio: None
zarr: None
cftime: None
nc_time_axis: None
PseudoNetCDF: None
iris: None
bottleneck: None
dask: 2023.4.0
distributed: 2023.4.0
matplotlib: 3.5.3
cartopy: None
seaborn: 0.12.2
numbagg: None
fsspec: 2023.4.0
cupy: None
pint: None
sparse: None
flox: None
numpy_groupies: None
setuptools: 67.7.1
pip: 23.1.1
conda: 23.3.1
pytest: 7.3.1
mypy: 1.2.0
IPython: 8.12.0
sphinx: 6.1.3
```
That's interesting.
I just realized that the `_typed_ops.pyi` type definitions are wrong.
For example, `DataArray.__add__` will return `NotImplemented` for `Dataset` operands, which in turn should make Python call `Dataset.__radd__`.
The type definitions should therefore simply not accept `Dataset` at all; mypy will recognize this and use `Dataset.__radd__` instead.
So all these overloads are useless.
But I'm not sure whether this solves the problem; I haven't tested it yet!
In a mini test example this worked for me:
```python
class A:
    i: int

    def __init__(self, i: int) -> None:
        self.i = i

    def __add__(self, other: A | int) -> A:
        if isinstance(other, A):
            return A(self.i + other.i)
        if isinstance(other, int):
            return A(self.i + other)
        return NotImplemented

    def __radd__(self, other: A | int) -> A:
        return self + other

    def __str__(self):
        return f"A({self.i})"


class B:
    ii: dict[str, int]

    def __init__(self, ii: dict[str, int]) -> None:
        self.ii = ii

    def __add__(self, other: A | B | int) -> B:
        if isinstance(other, B):
            return B(
                {k: i + other.ii.get(k, 0) for k, i in self.ii.items()}
                | {k: o for k, o in other.ii.items() if k not in self.ii}
            )
        if isinstance(other, A):
            return B({k: i + other.i for k, i in self.ii.items()})
        if isinstance(other, int):
            return B({k: i + other for k, i in self.ii.items()})
        return NotImplemented

    def __radd__(self, other: A | B | int) -> B:
        return self + other

    def __str__(self):
        return f"B({self.ii})"


a = A(1)
b = B({"a": 5})

print(a + 1)
print(1 + a)
print(a + a)
print(b + 5)
print(5 + b)
print(b + b)
print(a + b)
print(b + a)

reveal_type(a + 1)
reveal_type(1 + a)
reveal_type(a + a)
reveal_type(b + 5)
reveal_type(5 + b)
reveal_type(b + b)
reveal_type(a + b)
reveal_type(b + a)
```
Note how `A` knows nothing about `B`, neither in the code nor in the typing.
I have entered quite the rabbit hole here. This is not only a typing issue anymore :/
Somehow the operators used in xarray are not doing what the Python data model dictates: return `NotImplemented` when you encounter an unknown type. Instead, we pass everything down to the underlying data types.
This results in weird error messages like:
```python
import xarray as xr

xr.DataArray([1, 2, 3]) + {"asd"}
```

raises:

```
TypeError: unsupported operand type(s) for +: 'int' and 'set'
```

but it should be `TypeError: unsupported operand type(s) for +: 'xarray.core.DataArray' and 'set'`.
It seems that we have done this only for xarray objects, but not for all other types.
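To illustrate the convention the data model expects, here is a minimal sketch (the `Wrapper` class is purely illustrative and not xarray code): returning `NotImplemented` for unknown operands lets Python try the reflected method and, failing that, raise a `TypeError` that names both operand types.

```python
class Wrapper:
    def __init__(self, data):
        self.data = data

    def __add__(self, other):
        if isinstance(other, (Wrapper, int, float)):
            other_data = other.data if isinstance(other, Wrapper) else other
            return Wrapper(self.data + other_data)
        # Unknown type: let Python try the reflected method, then raise a
        # TypeError mentioning 'Wrapper' and the other operand's type.
        return NotImplemented


try:
    Wrapper(1) + {"asd"}
except TypeError as e:
    print(e)  # unsupported operand type(s) for +: 'Wrapper' and 'set'
```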
Back to static typing:
I still don't know how best to type the `other` argument of these ops...
In principle, this method is supposed to accept ANYTHING, so the type should be `object`, and then I might have to add some overloads for the compatible data and their return types.
However, `-> NotImplemented`, `Literal[NotImplemented]`, or similar does not work. It seems that mypy is happy if I remove this overload entirely (i.e. do not use `other: object` at all and only keep the compatible one); I only get some complaints about the `Mapping` class, which could be ignored.
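As a sketch of that approach (the `Num` class is illustrative, not xarray's actual stubs): annotate `__add__` with only the compatible types and keep the runtime `NotImplemented` fallback out of the signature; mypy special-cases `return NotImplemented` in dunder operator methods and falls back to the other operand's reflected method for anything not in the annotation.

```python
from __future__ import annotations


class Num:
    def __init__(self, v: int) -> None:
        self.v = v

    # Only compatible types in the signature -- no catch-all `object` overload.
    def __add__(self, other: Num | int) -> Num:
        if isinstance(other, Num):
            return Num(self.v + other.v)
        if isinstance(other, int):
            return Num(self.v + other)
        # Runtime fallback for unknown types; mypy allows this in dunders.
        return NotImplemented


print((Num(1) + 2).v)
```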
Ok, I could improve the typing of the ops a bit, but I fear that numpy is too dominant here...
Numpy types its `__add__` with a type that includes `SupportsArray`, a Protocol that defines an `__array__` method.
In other words: an `xarray.Variable` falls into this type, and the type checker will always think that `np.ndarray.__add__(self, xr.Variable(...)) -> np.ndarray`, just as you wrote.
So this particular problem is a numpy upstream problem :/ Feel free to open a PR over there, but I don't see a quick fix for this on numpy's side...
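A minimal demonstration of the mechanism (a sketch with an illustrative `HasArray` class, not xarray code): any object exposing an `__array__` method matches numpy's array-like protocol, so numpy's stubs type `ndarray.__add__` as returning `ndarray` for such operands and mypy never considers the reflected `__radd__`.

```python
import numpy as np


class HasArray:
    # Exposing __array__ makes this object array-like to numpy; xarray's
    # Variable exposes it too, which is why it matches numpy's stubs.
    def __array__(self, dtype=None, copy=None):
        return np.array([1, 2, 3])


x = np.array([10, 20, 30])
# At runtime numpy coerces the operand via __array__; statically, numpy's
# stubs tell mypy this returns an ndarray, never deferring to __radd__.
result = x + HasArray()
print(type(result), result)
```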
I coincidentally came across this today — I think it might be the same issue, but can open a new issue if not.
With this code:
```diff
diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
index 3c40d0a2..193cde41 100644
--- a/xarray/tests/test_variable.py
+++ b/xarray/tests/test_variable.py
@@ -26,6 +26,7 @@
     VectorizedIndexer,
 )
 from xarray.core.pycompat import array_type
+from xarray.core.types import T_DataWithCoords, T_Xarray
 from xarray.core.utils import NDArrayMixin
 from xarray.core.variable import as_compatible_data, as_variable
 from xarray.tests import (
@@ -395,6 +396,9 @@ def test_1d_math(self, dtype: np.typing.DTypeLike) -> None:
         assert isinstance(0 + v, Variable)
         assert not isinstance(0 + v, IndexVariable)

+    def test_type_pow(self, x: T_DataWithCoords) -> T_DataWithCoords:
+        return x**2
+
     def test_1d_reduce(self):
         x = np.arange(5)
         v = self.cls(["x"], x)
```
We get this error:
```
xarray/tests/test_variable.py: note: In member "test_type_pow" of class "VariableSubclassobjects":
xarray/tests/test_variable.py:400: error: Unsupported operand types for ** ("T_DataWithCoords" and "int")
```
This might be because the typed ops mixin is only added into the class hierarchy at the `Variable`, `DataArray` and `Dataset` level, but not for `DataWithCoords`.
If you change that to `T_Xarray`, it should work.
Yes, you're completely correct — thanks!
I'll minimize these to reduce the noise.
Related:
```python
[1, 2, 3] == DataArray([1, 2, 3], dims=["t"])
```
will be inferred as `bool`. The same holds for basically anything, including `np.ndarray`. There is currently no way to prevent this; it requires a typeshed update.
Caused by https://github.com/python/typeshed/issues/8217 / https://discuss.python.org/t/make-type-hints-for-eq-of-primitives-less-strict/34240
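A small sketch of the runtime/static mismatch (using a plain `np.ndarray` for self-containedness): at runtime `list.__eq__` returns `NotImplemented` for an array operand, so numpy's reflected `__eq__` produces an elementwise boolean array, while typeshed annotates the comparison as returning `bool`.

```python
import numpy as np

# Runtime: list.__eq__ returns NotImplemented for an ndarray, so numpy's
# reflected __eq__ takes over and returns an elementwise boolean array.
# Static typing: typeshed declares list.__eq__ -> bool, so mypy infers
# `bool` here and never considers the reflected method.
result = [1, 2, 3] == np.array([1, 2, 3])
print(type(result))  # <class 'numpy.ndarray'> at runtime
```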
Is there anything we can do here or should we close?