jaxtyping
Better type hint for `PyTree`
jaxtyping currently defines `PyTree` for static type checkers like this:

```python
# Set up to deliberately confuse a static type checker.
PyTree: TypeAlias = getattr(typing, "foo" + "bar")
# What's going on with this madness?
#
# At static-type-checking-time, we want `PyTree` to be a type for which both
# `PyTree` and `PyTree[Foo]` are equivalent to `Any`.
# (The intention is that `PyTree` be a runtime-only type; there's no real way to
# do more with static type checkers.)
```
I think this can be improved using Python 3.13 features (the `type` statement from PEP 695 plus type-parameter defaults from PEP 696):
```python
from typing import Any

type PyTree[T=Any] = Any
```
Code sample in the Pyright playground:
```python
from typing import Any, reveal_type

type PyTree[T=Any] = Any

foo: PyTree[int] = "aga"
bar: PyTree = 123

def baz(arg: PyTree, other: PyTree[int]):
    reveal_type(arg)  # Any
    reveal_type(other)  # Any

reveal_type(foo)  # Any
reveal_type(bar)  # Any
```
Note: the mypy playground currently throws a syntax error for some reason. Running it locally with Python 3.13.2 and mypy 1.15.0 works fine.
Thanks for the heads-up! As this uses new syntactic features I suspect we'll need to wait until 3.12 is EOL before we can use this.
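To make the version constraint concrete: `type PyTree[T=Any] = Any` needs both the `type` statement (Python 3.12) and type-parameter defaults (Python 3.13), and a module containing it fails to compile at all on older interpreters, even if the line sits inside an `if TYPE_CHECKING:` block. The sketch below is not from jaxtyping; `PROPOSED_ALIAS`, `supports_proposed_alias`, and `alias_resolves_to_any` are hypothetical names chosen for illustration. It keeps the alias in a string so the file itself parses on any version, then checks whether the running interpreter accepts the syntax and, on 3.13+, what the alias looks like at runtime.

```python
import sys
from typing import Any

# Hypothetical constant: the proposed alias, kept in a string so that *this* file
# still parses on interpreters that predate the syntax (a literal
# `type X[T=...]` line would be a SyntaxError for the whole module on <= 3.12).
PROPOSED_ALIAS = "from typing import Any\ntype PyTree[T=Any] = Any\n"


def supports_proposed_alias() -> bool:
    """Hypothetical helper: True if this interpreter can compile the alias."""
    try:
        compile(PROPOSED_ALIAS, "<pytree_alias>", "exec")
    except SyntaxError:
        return False
    return True


def alias_resolves_to_any() -> bool:
    """Hypothetical helper (Python 3.13+ only): execute the alias and check that
    both its value and the default of its type parameter are `typing.Any`."""
    namespace: dict[str, object] = {}
    exec(PROPOSED_ALIAS, namespace)
    alias = namespace["PyTree"]  # a typing.TypeAliasType instance
    (param,) = alias.__type_params__  # the single type parameter `T`
    # Both attributes are evaluated lazily on first access.
    return alias.__value__ is Any and param.__default__ is Any


if __name__ == "__main__":
    print(sys.version_info[:2], supports_proposed_alias())
    if supports_proposed_alias():
        print(alias_resolves_to_any())
```

On Python 3.12 and earlier the first helper returns `False`, so a library that still supports 3.12 cannot ship the alias in a `.py` file; on 3.13+ the alias is a real runtime object whose value is simply `Any`.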