jaxtyping
Stateful Equinox Module: how to annotate?
I recently came up with the following code:
from typing import Self

import equinox as eqx
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)  # to typecheck __init__
@beartype
class Accumulator(eqx.Module):
    x: Float[Array, " n"]

    @jaxtyped
    def add(self, y: Float[Array, " n"]) -> Self:
        return self.__class__(self.x + y)
Now, when running this code, jaxtyped complained in a UserWarning saying that it prefers the @jaxtyped(typechecker=beartype) syntax. (This warning was added before beartype's __instancecheck_str__ pseudostandard was implemented.) However, in this context, that syntax leads to an error from beartype, because it lacks the context to figure out what typing.Self refers to. Therefore the code above is the only way I found to get it running.
However, this Accumulator faces an issue: if you write
@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2
When calling acc2.add(y), it seems that n=3 is still in the memo from the previous acc1.add(x) call, so a BeartypeCallHintParamViolation type-check error is raised.
So, my question is: how does one properly type-annotate this kind of class?