jaxtyping

Stateful Equinox Module: how to annotate?

Open · EtaoinWu opened this issue 4 months ago · 2 comments

I recently came up with the following code:

```python
from typing import Self

import equinox as eqx
from beartype import beartype
from jax import numpy as jnp
from jaxtyping import Array, Float, jaxtyped


@jaxtyped(typechecker=beartype)  # to typecheck __init__
@beartype
class Accumulator(eqx.Module):
    x: Float[Array, " n"]

    @jaxtyped  # bare form; this is what triggers the UserWarning
    def add(self, y: Float[Array, " n"]) -> Self:
        return self.__class__(self.x + y)
```

When I run this code, jaxtyping emits a UserWarning about the bare @jaxtyped on add, saying that it prefers the @jaxtyped(typechecker=beartype) syntax. (This warning was added before beartype implemented the __instancecheck_str__ pseudostandard.) In this context, however, applying that syntax to add makes beartype raise an error, because it lacks the context to work out what typing.Self refers to. As far as I can tell, the code above is therefore the only way to get it running.

However, this Accumulator has a problem. Suppose you write:

```python
@jaxtyped(typechecker=beartype)
def test_accumulator():
    x = jnp.ones(3)
    y = jnp.ones(4)
    acc1 = Accumulator(x)
    acc1 = acc1.add(x)
    acc2 = Accumulator(y)
    acc2 = acc2.add(y)
    return acc1, acc2
```

When acc2.add(y) is called, it seems that n=3 is still in the memo from the earlier acc1.add(x) call, so beartype raises a BeartypeCallHintParamViolation.
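The hypothesis above can be illustrated with a toy model in plain Python. This is not jaxtyping's implementation; memo_stack, check_dim, fresh_scope, caller_scope and run_test are hypothetical names. Each checked scope holds a dict binding dimension names to sizes. A wrapper that checks its argument inside a fresh scope (roughly what a function decorated with @jaxtyped(typechecker=beartype) gets) accepts n=3 and then n=4 in two separate calls, whereas a wrapper whose check runs in the caller's open scope (as beartype's outer wrapper around the bare @jaxtyped add appears to) binds n=3 on the first call and rejects n=4 on the second.

```python
# Toy model of shape memos (hypothetical; not jaxtyping's real internals).
memo_stack = [{}]


def check_dim(name, size):
    # The first use of `name` in the current scope binds it; later uses must match.
    bound = memo_stack[-1].setdefault(name, size)
    if bound != size:
        raise TypeError(f"dimension {name!r}: expected {bound}, got {size}")


def fresh_scope(fn):
    # Check the argument inside a new memo, so bindings do not leak between calls.
    def wrapper(n):
        memo_stack.append({})
        try:
            check_dim("n", n)
            return fn(n)
        finally:
            memo_stack.pop()
    return wrapper


def caller_scope(fn):
    # Check the argument in whatever memo the caller currently has open.
    def wrapper(n):
        check_dim("n", n)
        return fn(n)
    return wrapper


def run_test(add):
    # Plays the role of test_accumulator: one outer memo around both calls.
    memo_stack.append({})
    try:
        add(3)  # like acc1.add(x) with n=3
        add(4)  # like acc2.add(y) with n=4
        return "ok"
    except TypeError as err:
        return f"error: {err}"
    finally:
        memo_stack.pop()


print(run_test(fresh_scope(lambda n: n)))   # ok
print(run_test(caller_scope(lambda n: n)))  # error: dimension 'n': expected 3, got 4
```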

So my question is: how does one properly type-annotate this kind of class?

EtaoinWu, Oct 03 '24 12:10