jaxtyping

Batched type

Open · joeryjoery opened this issue 2 years ago · 2 comments

Hi, I was wondering whether a type annotation for batched data structures exists, or whether one could be implemented.

That is, something along the lines of PyTree, but indicating that every leaf has a leading axis of the same size, e.g. N.

I know I can already write this as e.g. PyTree[Float32[Array, 'N ...']], but that way I can't reuse my existing type aliases (or don't know how to).

Example use case:

import jax
from jaxtyping import Array, Float32, PyTree

MyArray = Float32[Array, '...']
Batched = ???  # Should prepend 'N', so that Batched[MyArray] means Float32[Array, 'N ...']

def sample_fun(x: PyTree[MyArray]) -> PyTree[MyArray]:
    return x

def batch_fun(xs: Batched[PyTree[MyArray]]) -> PyTree[Batched[MyArray]]:
    return jax.vmap(sample_fun)(xs)

Note that, intuitively, Batched[PyTree[...]] should be equivalent to PyTree[Batched[...]], since both act on the leaves.

joeryjoery, May 15 '23 14:05
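To make the intended semantics concrete: jax.vmap maps over the leading axis of every leaf of its input pytree, so a "batched pytree" is exactly a pytree whose leaves all share one leading size N. The sketch below assumes only JAX; the tree contents and the helper names `leading_sizes` and `vmap_accepts` are hypothetical, made up for illustration. It shows a consistent tree being mapped and a tree with disagreeing leading axes being rejected by vmap.

```python
import jax
import jax.numpy as jnp


def sample_fun(x):
    # Works on a single, unbatched pytree: reduce every leaf to a scalar.
    return jax.tree_util.tree_map(lambda leaf: leaf.sum(), x)


def leading_sizes(tree):
    # Hypothetical helper: the set of leading-axis sizes across all leaves.
    return {leaf.shape[0] for leaf in jax.tree_util.tree_leaves(tree)}


def vmap_accepts(tree):
    # Hypothetical helper: jax.vmap maps over axis 0 of *every* leaf,
    # so it raises ValueError when the leaves disagree on that size.
    try:
        jax.vmap(sample_fun)(tree)
        return True
    except ValueError:
        return False


# A "batched pytree": every leaf has the same leading axis, N = 4.
xs = {"a": jnp.ones((4, 3)), "b": (jnp.zeros((4,)), jnp.ones((4, 2, 2)))}
# Leaves disagree on the leading axis (4 vs 5).
mismatched = {"a": jnp.ones((4, 3)), "b": jnp.ones((5,))}

ys = jax.vmap(sample_fun)(xs)
print(leading_sizes(xs))                              # {4}
print(jax.tree_util.tree_map(lambda y: y.shape, ys))  # {'a': (4,), 'b': ((4,), (4,))}
print(vmap_accepts(mismatched))                       # False
```

This is why a single shared N across all leaves is the natural meaning of both Batched[PyTree[...]] and PyTree[Batched[...]].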

With a bit of trickery, this is possible!

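from typing import TypeVar, Union, TYPE_CHECKING
from jaxtyping import Array, Float32, Shaped

if TYPE_CHECKING:  # needed for static type checking compatibility
    # Static type checkers see Batched[T] as Union[T, _Unused],
    # which accepts any value of type T.
    class _Unused:
        pass

    T = TypeVar("T")
    Batched = Union[T, _Unused]
else:
    # At runtime, Batched[X] prepends a leading "N" axis to the jaxtyping annotation X.
    class Batched:
        def __class_getitem__(cls, item):
            return Shaped[item, "N"]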

This doesn't support Batched[PyTree[...]] though; you should write that as PyTree[Batched[...]]. (With a bit of hackery you could detect Batched[PyTree[...]] in __class_getitem__ and convert it, if you really wanted to.)

patrick-kidger, May 15 '23 15:05

Hi Patrick, thanks a lot for your quick reply! This looks great. I was playing around with it, and since I also have some PyTree[leaf] type aliases, I modified the __class_getitem__ dunder to extract the PyTree leaf type when possible.

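from jaxtyping import PyTree, Shaped

class Batched:
    def __class_getitem__(cls, item):
        # PyTree[X] carries its leaf type X as `.leaftype`:
        # rewrite Batched[PyTree[X]] as PyTree[Shaped[X, 'N']].
        if hasattr(item, 'leaftype'):
            return PyTree[Shaped[item.leaftype, 'N']]
        # Plain array annotation: prepend the 'N' axis directly.
        return Shaped[item, 'N']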

joeryjoery, May 16 '23 06:05