How to specify the per-leaf shape of a pytree?
How can I specify the exact shape of each leaf of a pytree, say a dict? And, going further, the tree structure of the pytree itself?
For example, function f takes a dict as input:
@jaxtyped(typechecker=beartype)
def f(state: Dict{'x': Float[Array, 'b 10'], 'y': Float[Array, 'b 1']}):
    ...
# Example valid input
valid_state = {
    'x': jnp.ones((3, 10)),  # b=3
    'y': jnp.zeros((3, 1))  # b=3, consistent
}
# Example invalid input (wrong shape for 'x')
invalid_state = {
    'x': jnp.ones((3, 99)),  # Shape is not 'b 10'
    'y': jnp.zeros((3, 1))
}
f(valid_state)
try:
    f(invalid_state)
except Exception as e:
    print(f"\nError with invalid_state:\n{e}")
(The snippet above does not work; it only illustrates what I would like to write.)
I would like the type checker to check the shape of every leaf as well as the tree structure.
PyTree[Float[Array, 'b ...']] is good but not fine-grained enough.
I think this feature is quite intuitive: in RL, for example, a JAX environment's step function takes a complex state. Maybe there is a way or a workaround, but I failed to find one. Sorry if I have missed something obvious.
Are you looking for something like this?
This also works by creating a type alias for T in the example above, with
from typing import TypeAlias
T: TypeAlias = PyTree[Shaped[Array, "?foo"], "T"]
which lets you use T directly as a type annotation. This also works if you define a custom structure for your complex states, e.g. with an Equinox module:
import equinox as eqx
from jaxtyping import Array
class ComplexState(eqx.Module):
    x: dict[str, Array]
    ...
In this case you can use ComplexState as a type annotation.
I'm afraid this is not what I'm looking for. Path-dependent shapes are about keeping two structured pytree parameters consistent in shape, leaf by leaf. What I am looking for is a way to specify the shape of every leaf of a single pytree, where the leaves do not all share the same shape. As for
... define a custom structure for your complex states ...
I can't see how to specify the shapes in the ComplexState definition. Can you elaborate on this?
I have found a workaround using a dataclass:
from dataclasses import dataclass
import jax.numpy as jnp
from beartype import beartype
from jaxtyping import Array, Shaped, jaxtyped
@jaxtyped(typechecker=beartype)
@dataclass
class State:
    x: Shaped[Array, "B F"]
    y: Shaped[Array, "B"]
@jaxtyped(typechecker=beartype)
def fn(state: State):
    pass
valid_state = State(x=jnp.ones((2, 3)), y=jnp.ones((2,)))
fn(valid_state)
try:
    # 'y' has shape (2, 1), which does not match "B"
    invalid_state = State(x=jnp.ones((2, 3)), y=jnp.ones((2, 1)))
    fn(invalid_state)
except Exception as e:
    print(f"Error: {e}")
I finally got NamedTuple to work. It is strange, though, that the class has to be decorated with @beartype rather than @jaxtyped(typechecker=beartype).
from typing import TypedDict, NamedTuple
from jaxtyping import Float, jaxtyped, Shaped, Array
import jax.numpy as jnp
from beartype import beartype
from dataclasses import dataclass
# @jaxtyped(typechecker=beartype)
@beartype
class State(NamedTuple):
    x: Shaped[Array, "B F"]
    y: Shaped[Array, "B"]
@jaxtyped(typechecker=beartype)
# @beartype
def fn(state: State):
    pass
valid_state = State(x=jnp.ones((2, 3)), y=jnp.ones((2,)))
fn(valid_state)
try:
    invalid_state = State(x=jnp.ones((2, 3)), y=jnp.ones((2, 1)))
    fn(invalid_state)
except Exception as e:
    print(f"Error: {e}")
It is also not clear why TypedDict does not work, i.e. when the class definition is replaced with
class State(TypedDict):
    x: Shaped[Array, "B F"]
    y: Shaped[Array, "B"]
So I think your example with dataclasses is successfully raising an error because the initialisation of the dataclass sits inside your try/except block, and it is this initialisation that fails. (An Equinox Module, which is a dataclass, would work just as well here.)
On why namedtuples and typeddicts behave the way they do: ultimately, this is all up to how the chosen typechecker (here beartype) decides to process them.
As for the central part of your question: jaxtyping works across a single call site. At that call site, it establishes an environment for recording the size of each axis, then calls the typechecker on each argument. Every time the typechecker performs an isinstance check against a jaxtyping annotation, the relevant sizes are recorded into this environment.
What that means is that if you want the axis sizes of a structured type to match those of the call site, then you need to ensure that when the typechecker interacts with it, it ends up performing those isinstance checks.
If you'd like to set up something like that then take a look at #328.
Thanks for your quick reply! You're absolutely right that the type check happens at variable initialization, not at function execution.
Things related to #328 seem complicated. Although it differs from what I originally wanted, having the type check at initialization is enough for me.