[BUG] Annotated Generics Not Behaving As Expected With Type Checking
Required prerequisites
- [x] I have read the documentation https://optree.readthedocs.io.
- [x] I have searched the Issue Tracker that this hasn't already been reported. (comment there if it has.)
What version of OpTree are you using?
0.17
System information
3.13.1 (main, Dec 6 2024, 20:13:21) [Clang 18.1.8 ] darwin 0.17.0
Problem description
I would expect the use of PyTree with Generics would let me properly annotate functions where my LSP and/or type checker will work nicely with them.
Here is a very simple example:
import torch
from optree import PyTree, tree_map

def example(data: PyTree[torch.Tensor]) -> PyTree[int]:
    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)
    return tree_map(get_ndim, data)

data = {"a": torch.zeros((3, 4))}
result = example(data)
In the above code, Pylance, Pyright, and basedpyright all accept the definition of example, but the final line `result = example(data)` yields the following diagnostic:
Argument of type "dict[str, Tensor]" cannot be assigned to parameter "data" of type "PyTree[Tensor]" in function "example"
  "dict[str, Tensor]" is not assignable to "PyTree[Tensor]"
Pylance [reportArgumentType](https://github.com/microsoft/pylance-release/blob/main/docs/diagnostics/reportArgumentType.md)
(variable) data: dict[str, Tensor]
Now, improved annotations were mentioned in this closed issue: https://github.com/metaopt/optree/issues/6, closed via this PR: https://github.com/metaopt/optree/pull/166
Based on the following example from the documentation, I believed my example should work, but it does not:
>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
My example passes a value of type dict[str, torch.Tensor], which should be accepted according to the union above (I confirmed that I get the same output on my machine). Yet all of the language servers reject it as an input.
Reproducible example code
The Python snippets:
import torch
from optree import PyTree, tree_map

def example(data: PyTree[torch.Tensor]) -> PyTree[int]:
    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)
    return tree_map(get_ndim, data)

data = {"a": torch.zeros((3, 4))}
result = example(data)
Traceback
Expected behavior
I would expect example(data) not to produce any diagnostics.
Additional context
No response
For context, in case it is helpful, using the below instead of PyTree has no type issues:
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, TypeAlias, TypeVar, Union, runtime_checkable

import torch
from optree import tree_map

@runtime_checkable
class MinimalArrayLike(Protocol):
    shape: Any
    dtype: Any

type Tree[T: MinimalArrayLike] = (
    T | tuple["Tree[T]", ...] | Sequence["Tree[T]"] | Mapping[str, "Tree[T]"]
)

def example(data: Tree[torch.Tensor]) -> Tree[int]:
    def get_ndim(tensor: torch.Tensor) -> int:
        return tensor.ndim
    return tree_map(get_ndim, data)

data = {"a": torch.zeros((3, 4))}
result = example(data)  # Type checkers will be happy!
Thanks for the feedback! I will investigate this in a few days.
Note that Python, unlike C/C++, cannot declare a type first and then use or implement it later.
// Forward declaration - define the type first
template<typename T>
struct PyTree;

// Define the recursive variant type - use it before PyTree is fully defined
template<typename T>
using PyTreeVariant = std::variant<
    T,                              // Leaf value
    std::vector<PyTree<T>>,         // Vector of PyTrees
    std::map<std::any, PyTree<T>>   // Map of PyTrees (any key type)
>;
It is hard to make recursive generic types work in Python. I used a custom generic class with forward references. Type checkers may not recognize the custom class, and the forward references may not be evaluated correctly.
>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
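The forward-reference behaviour can be reproduced with the standard `typing` module alone. The sketch below is not optree's implementation. `Direct` and `Tree` are hypothetical aliases made up for illustration. It shows two things: an alias cannot name itself directly because the right-hand side is evaluated before the name is bound, and a quoted self-reference is stored as an unevaluated `ForwardRef` that parameter substitution does not reach into.

```python
from typing import Dict, ForwardRef, List, Tuple, TypeVar, Union, get_args

T = TypeVar("T")

# Python evaluates the right-hand side before binding the name, so an alias
# cannot refer to itself directly: `Direct` does not exist yet.
try:
    Direct = Union[int, List[Direct]]  # noqa: F821
    self_reference_error = None
except NameError as exc:
    self_reference_error = exc

# Quoting the recursive reference defers it: typing wraps each string in a
# ForwardRef. `Tree` is a hypothetical alias, not optree's PyTree.
Tree = Union[T, Tuple["Tree[T]", ...], List["Tree[T]"], Dict[str, "Tree[T]"]]

# Substituting T replaces the leaf type only.
IntTree = Tree[int]
leaf_type, tuple_type, list_type, dict_type = get_args(IntTree)

# The element type of the list branch is still an unevaluated ForwardRef
# whose text mentions T, not int.
inner = get_args(list_type)[0]

print(type(self_reference_error).__name__)
print(leaf_type)
print(isinstance(inner, ForwardRef), inner.__forward_arg__)
```

A static checker never runs this code. It has to recognize the recursive pattern from the source text, which is why an alias built dynamically by a custom class is harder for it to follow than one written out literally.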
The working example for me with pyright is:
from typing import TypeAlias

import torch
from optree import PyTreeTypeVar, tree_map

TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', torch.Tensor)  # pyright: ignore[reportInvalidTypeForm]
IntTree: TypeAlias = PyTreeTypeVar('IntTree', int)  # pyright: ignore[reportInvalidTypeForm]

def example(data: TensorTree) -> IntTree:
    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)
    return tree_map(get_ndim, data)

data = {'a': torch.zeros((3, 4))}
result = example(data)
See also:
- #6
- python/mypy#731
- python/mypy#13693
- google/jax#3340
- Python Discussion: Generic `typing.ForwardRef` to support generic recursive types