[BUG] Annotated Generics Not Behaving As Expected With Type Checking

Open RyanSaxe opened this issue 1 month ago • 2 comments

Required prerequisites

What version of OpTree are you using?

0.17

System information

3.13.1 (main, Dec 6 2024, 20:13:21) [Clang 18.1.8 ] darwin 0.17.0

Problem description

I would expect that using PyTree with generics would let me annotate functions so that my LSP and/or type checker works with them.

Here is a very simple example:

import torch
from optree import PyTree, tree_map


def example(data: PyTree[torch.Tensor]) -> PyTree[int]:

    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)

    return tree_map(get_ndim, data)


data = {"a": torch.zeros((3, 4))}
result = example(data)

In the above code, Pylance, Pyright, and basedpyright all accept the definition of example, but the final line `result = example(data)` yields the following diagnostic:

Argument of type "dict[str, Tensor]" cannot be assigned to parameter "data" of type "PyTree[Tensor]" in function "example"
  "dict[str, Tensor]" is not assignable to "PyTree[Tensor]"
Pylance [reportArgumentType](https://github.com/microsoft/pylance-release/blob/main/docs/diagnostics/reportArgumentType.md)
(variable) data: dict[str, Tensor]

Now, improved annotations were mentioned in this closed issue: https://github.com/metaopt/optree/issues/6, closed via this PR: https://github.com/metaopt/optree/pull/166

The following example from the documentation makes me believe my code should work, but it does not:

>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]

My example passes a dict[str, torch.Tensor], which should be valid according to the output above (I confirmed that I get the same output on my machine), yet all of the language servers reject it as an input.
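
A small standard-library sketch of why the printed alias and the diagnostic can disagree: the union shown above is built when the code runs, whereas a static checker only reads the source and never executes it. The `IntTree` alias below is a hypothetical stand-in with `int` leaves so that it runs without torch or optree; it is an assumption for illustration, not how optree constructs `PyTree`.

```python
from collections import deque
from typing import Any, ForwardRef, Union, get_args, get_origin

# Hypothetical stand-in for the printed alias (int leaves, no CustomTreeNode member).
ref = ForwardRef("IntTree")
IntTree = Union[int, tuple[ref, ...], list[ref], dict[Any, ref], deque[ref]]

# Collect the container types that appear as members of the union at runtime.
member_origins = {get_origin(member) for member in get_args(IntTree)}
accepts_dict_at_runtime = dict in member_origins

# The recursive positions are still unevaluated ForwardRef objects.
list_member = next(m for m in get_args(IntTree) if get_origin(m) is list)
inner = get_args(list_member)[0]

print(accepts_dict_at_runtime, type(inner).__name__, inner.__forward_arg__)
```

At runtime the union does list `dict` among its members, which matches the output above, but that says nothing about what a static checker infers for `PyTree[...]` from the source alone.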

Reproducible example code

The Python snippets:

import torch
from optree import PyTree, tree_map


def example(data: PyTree[torch.Tensor]) -> PyTree[int]:

    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)

    return tree_map(get_ndim, data)


data = {"a": torch.zeros((3, 4))}
result = example(data)

Traceback


Expected behavior

I would expect example(data) not to produce any diagnostics.

Additional context

No response

RyanSaxe avatar Nov 05 '25 22:11 RyanSaxe

For context, in case it is helpful, using the alias below instead of PyTree produces no type errors:

from collections.abc import Mapping, Sequence
from typing import Any, Protocol, runtime_checkable

import torch
from optree import tree_map


@runtime_checkable
class MinimalArrayLike(Protocol):
    shape: Any
    dtype: Any


type Tree[T: MinimalArrayLike] = (
    T | tuple["Tree[T]", ...] | Sequence["Tree[T]"] | Mapping[str, "Tree[T]"]
)


def example(data: Tree[torch.Tensor]) -> Tree[int]:
    def get_ndim(tensor: torch.Tensor) -> int:
        return tensor.ndim

    return tree_map(get_ndim, data)


data = {"a": torch.zeros((3, 4))}
result = example(data)  # Type checkers will be happy!
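
One general property that may make a `Mapping`/`Sequence`-based alias friendlier to type checkers is variance: `dict` is invariant in its value type, while `Mapping` is covariant. The sketch below illustrates that property with the standard library only; `Node`, `sum_invariant`, and `sum_covariant` are hypothetical names, and whether variance is what triggers the diagnostic in this issue is not established here.

```python
from collections.abc import Mapping
from typing import Union

# Hypothetical one-level "tree" value: a leaf or a list of leaves.
Node = Union[int, list[int]]


def total(node: Node) -> int:
    # Sum a single node, whether it is a leaf or a list of leaves.
    return node if isinstance(node, int) else sum(node)


def sum_invariant(values: dict[str, Node]) -> int:
    return sum(total(v) for v in values.values())


def sum_covariant(values: Mapping[str, Node]) -> int:
    return sum(total(v) for v in values.values())


flat: dict[str, int] = {"a": 1, "b": 2}

# Accepted statically: Mapping is covariant in its value type.
covariant_total = sum_covariant(flat)

# Runs fine, but static checkers are expected to flag this call because
# dict[str, int] is not assignable to dict[str, Node] under invariance.
invariant_total = sum_invariant(flat)
```

Both calls behave identically at runtime; only the static verdict is expected to differ.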

RyanSaxe avatar Nov 05 '25 22:11 RyanSaxe

Thanks for the feedback! I will investigate this in a few days.

Note that Python, unlike C/C++, cannot declare a type first and then use or implement it later.

#include <any>
#include <map>
#include <variant>
#include <vector>

// Forward declaration - declare the type first
template<typename T>
struct PyTree;

// Define the recursive variant type - use it before PyTree is fully defined
template<typename T>
using PyTreeVariant = std::variant<
    T,                             // Leaf value
    std::vector<PyTree<T>>,        // Vector of PyTrees
    std::map<std::any, PyTree<T>>  // Map of PyTrees (any key type)
>;

Recursive generic types are hard to express in Python. I used a custom generic class with forward refs; type checkers may not recognize the class, and the forward refs may not be evaluated correctly.

>>> import torch
>>> TensorTree = PyTree[torch.Tensor]
>>> TensorTree
typing.Union[torch.Tensor,
             tuple[ForwardRef('PyTree[torch.Tensor]'), ...],
             list[ForwardRef('PyTree[torch.Tensor]')],
             dict[typing.Any, ForwardRef('PyTree[torch.Tensor]')],
             collections.deque[ForwardRef('PyTree[torch.Tensor]')],
             optree.typing.CustomTreeNode[ForwardRef('PyTree[torch.Tensor]')]]
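
The difficulty with recursive aliases described above can be reproduced with the standard library alone. This is a minimal sketch with hypothetical names (`define_without_quotes`, `QuotedTree`), not optree's implementation: an alias cannot name itself while it is still being defined, so the recursive position has to be a string or `ForwardRef` that something must resolve later.

```python
from typing import Union, get_args


def define_without_quotes():
    # The right-hand side runs before the name Tree is bound, so this raises.
    Tree = Union[int, list[Tree]]
    return Tree


try:
    define_without_quotes()
    naive_error = None
except NameError as exc:  # UnboundLocalError is a subclass of NameError
    naive_error = type(exc).__name__

# Quoting the self-reference defers it: the definition succeeds,
# but the recursive position is stored as a plain, unevaluated string.
QuotedTree = Union[int, list["QuotedTree"]]
quoted_list_member = get_args(QuotedTree)[1]
unresolved = get_args(quoted_list_member)[0]

print(naive_error, repr(unresolved))
```

The quoted form defines cleanly, but nothing in the alias itself ties the string back to `QuotedTree`; that resolution step is where tools can disagree.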

The example that works for me with pyright is:

from typing import TypeAlias

import torch
from optree import PyTreeTypeVar, tree_map


TensorTree: TypeAlias = PyTreeTypeVar('TensorTree', torch.Tensor)  # pyright: ignore[reportInvalidTypeForm]
IntTree: TypeAlias = PyTreeTypeVar('IntTree', int)  # pyright: ignore[reportInvalidTypeForm]


def example(data: TensorTree) -> IntTree:
    def get_ndim(tensor: torch.Tensor) -> int:
        return len(tensor.shape)

    return tree_map(get_ndim, data)


data = {'a': torch.zeros((3, 4))}
result = example(data)

See also:

XuehaiPan avatar Nov 06 '25 06:11 XuehaiPan