Batched type
Hi, I was wondering whether a type annotation for batched data structures exists, or whether one could be implemented.
So, something like PyTree, but indicating that every leaf node has a leading axis of the same size, e.g. N.
I know I can already do this easily with, e.g., PyTree[Float32[Array, 'N ...']], but that way I can't reuse my type aliases (or don't know how to).
Example use-case:
```python
from jaxtyping import Array, Float32, PyTree
from jax import vmap

MyArray = Float32[Array, '...']
Batched = ???  # Should prepend 'N' to the shape, i.e. Batched[MyArray] == Float32[Array, 'N ...']

def sample_fun(x: PyTree[MyArray]) -> PyTree[MyArray]:
    return x

def batch_fun(xs: Batched[PyTree[MyArray]]) -> PyTree[Batched[MyArray]]:
    return vmap(sample_fun)(xs)
```
Note that, intuitively, Batched[PyTree[...]] should be equivalent to PyTree[Batched[...]], since both should act on the leaves.
With a bit of trickery, this is possible!
```python
from typing import TypeVar, Union, TYPE_CHECKING
from jaxtyping import Array, Float32, Shaped

if TYPE_CHECKING:  # needed for static type checking compatibility
    class _Unused:
        pass

    T = TypeVar("T")
    Batched = Union[T, _Unused]
else:
    class Batched:
        def __class_getitem__(cls, item):
            # Nesting a jaxtyping annotation prepends a leading "N" axis to item's shape.
            return Shaped[item, "N"]
```
This doesn't support Batched[PyTree[...]] though; you should write that as PyTree[Batched[...]]. (With a bit of hackery you could detect Batched[PyTree[...]] in __class_getitem__ and convert it, if you really wanted to.)
Hi Patrick, thanks a lot for your quick reply! This looks great. I was playing around with it, and since I also have some PyTree[leaf] type aliases, I modified __class_getitem__ to unwrap the PyTree leaf type when there is one:
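```python
from jaxtyping import PyTree, Shaped

class Batched:
    def __class_getitem__(cls, item):
        # Rewrite Batched[PyTree[X]] as PyTree[Batched[X]] using PyTree's leaftype.
        if hasattr(item, 'leaftype'):
            return PyTree[Shaped[item.leaftype, 'N']]
        return Shaped[item, 'N']
```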