Bug: runtime type-checking fails for a flax.struct.dataclass passed to a vmapped function
Hey, runtime type-checking seems to fail when a Flax dataclass is passed to a vmapped function. I wasn't able to find any related resources. Here is a minimal reproduction along with the associated error.
```python
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array


@flax.struct.dataclass
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(1, dtype=int))
jax.vmap(f)(data)
```
It raises the following error (with beartype):

```
E   jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of Data.
E   The problem arose whilst typechecking argument 'a'.
E   Called with arguments: {'self': Data(...), 'a': <object object at 0x7fc7c87e8fc0>}
E   Parameter annotations: (self: Any, a: jax.Array).
```
Here are the versions I'm using:

```
flax==0.8.1
jax==0.4.21
jaxtyping==0.2.25
```
I tested it, and it works with chex.dataclass and equinox.Module, but in my case I have no choice but to use Flax dataclasses. I'd love to find a workaround. Thanks!!
That's odd -- I've just tried running your code (with the same versions of each library) and don't see the same issue. Can you perhaps double-check in a new environment?
I see, my minimal reproduction was incomplete, sorry about that. I figured out that it depends on the order of the @jaxtyped decorator. This fails:
```python
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped


@jaxtyped(typechecker=beartype.beartype)
@flax.struct.dataclass
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int))
jax.vmap(f)(data)
```
This doesn't fail:
```python
import beartype
import flax
import jax
import jax.numpy as jnp
from jaxtyping import Array, jaxtyped


@flax.struct.dataclass
@jaxtyped(typechecker=beartype.beartype)
class Data:
    a: Array


def f(x: Data) -> int:
    return 1


data = Data(a=jnp.ones(10, dtype=int))
jax.vmap(f)(data)
```
I think the error occurred for me because I used the pytest hook, which, according to the docs, adds the jaxtyped decorator on top (outermost).
Tell me if you can reproduce this :smile: (I have beartype==0.17.2 but I don't think it matters)
Ah, thank you!
It looks like this is a bug in Flax itself. Here's a MWE that doesn't use jaxtyping:
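```python
import flax
import jax.tree_util as jtu


@flax.struct.dataclass
class A:
    x: int

    def __init__(self):
        # Set the field by hand (the dataclass is frozen) so that flattening succeeds.
        object.__setattr__(self, "x", 1)


leaves, treedef = jtu.tree_flatten(A())
jtu.tree_unflatten(treedef, leaves)  # fails: unflattening calls A(x=1)
```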
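It looks like the reason is that Flax's tree-unflattening rule calls the type's __init__ method. That is a long-standing gotcha when using JAX: unflattening is not guaranteed to receive real arrays as leaves (your traceback shows an <object object at ...> placeholder reaching the type-checked __init__ during jax.vmap), so any validation inside __init__ can fail.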
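To make the gotcha concrete without Flax or jaxtyping, here is a minimal sketch using only jax.tree_util. Everything named here is hypothetical: _check_array stands in for a runtime type-checker, ViaInit rebuilds itself through its constructor (the pattern described above for Flax), ViaNew bypasses the constructor with object.__new__, and unflatten_with_placeholder imitates the bare object() leaf visible in the traceback. It is an illustration under those assumptions, not Flax's or JAX's actual code.

```python
import dataclasses

import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def _check_array(value):
    # Hypothetical stand-in for a runtime type-checker such as jaxtyping + beartype.
    if not isinstance(value, jax.Array):
        raise TypeError(f"expected a jax.Array, got {type(value).__name__}")


@dataclasses.dataclass(frozen=True)
class ViaInit:
    a: jax.Array

    def __post_init__(self):
        # Runs on every call to the dataclass-generated __init__.
        _check_array(self.a)


# Unflatten by calling the constructor, so validation reruns on whatever leaves arrive.
jtu.register_pytree_node(
    ViaInit,
    lambda d: ((d.a,), None),
    lambda aux, children: ViaInit(*children),
)


@dataclasses.dataclass(frozen=True)
class ViaNew:
    a: jax.Array

    def __post_init__(self):
        _check_array(self.a)


def _unflatten_via_new(aux, children):
    # Skip __init__/__post_init__: allocate the instance and set the field directly.
    # object.__setattr__ is needed because the dataclass is frozen.
    obj = object.__new__(ViaNew)
    object.__setattr__(obj, "a", children[0])
    return obj


jtu.register_pytree_node(ViaNew, lambda d: ((d.a,), None), _unflatten_via_new)


def unflatten_with_placeholder(tree):
    # Hypothetical helper: rebuild `tree` with a bare object() in place of every leaf.
    leaves, treedef = jtu.tree_flatten(tree)
    return jtu.tree_unflatten(treedef, [object()] * len(leaves))


if __name__ == "__main__":
    try:
        unflatten_with_placeholder(ViaInit(a=jnp.ones(3)))
    except TypeError as err:
        print("ViaInit failed:", err)
    rebuilt = unflatten_with_placeholder(ViaNew(a=jnp.ones(3)))
    print("ViaNew rebuilt with a leaf of type", type(rebuilt.a).__name__)
    # With the __init__-free unflatten rule, vmap over the pytree works as usual.
    print(jax.vmap(lambda d: d.a.sum())(ViaNew(a=jnp.ones((2, 3)))))
```

The design choice is simply that unflattening should only restore fields, never re-run user logic; validation then happens once, when you construct the object yourself.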
Unsurprisingly, I'd recommend using Equinox instead :)