Type hinting for `eqx.Module` created under `vmap`
Hello, I have something like the following pattern:
import equinox as eqx
from jaxtyping import Array, Float

#
# Library code
#

class TestModule(eqx.Module):
    a: Float[Array, ""]
    b: Float[Array, "2"]

    def __init__(self, a, b):
        # Simplest possible body: just assign the fields.
        self.a = a
        self.b = b

def compute_something(vmapped_module: TestModule):
    """Specifically takes in TestModule with a batch dimension. How to type hint?"""
    ...

#
# Runtime code
#

import jax.random as jr

@eqx.filter_vmap
def make_module(a, b):
    return TestModule(a, b)

dim = 10
# Split the key so that `a` and `b` get independent random draws.
key_a, key_b = jr.split(jr.key(1234))
a = jr.normal(key_a, shape=(dim,))
b = jr.normal(key_b, shape=(dim, 2))
vmapped_module = make_module(a, b)  # .a has shape (10,), .b has shape (10, 2)
compute_something(vmapped_module)
How would I type hint the `compute_something` function, which specifically takes in a `TestModule` with a particular batch dimension (here, a leading axis of size `dim` on every array field)?
So this isn't something we have a nice way to annotate in jaxtyping right now. I think supporting this is the same as https://github.com/patrick-kidger/jaxtyping/issues/242. I'm not sure how fiddly that would be, but if it doesn't require too much magic then it's something I'd be happy to take a PR on.
Aha, the solution in #242 is something like what I had in mind. My only comment is that my modules can have non-arrays as leaves, so using `filter_vmap` under the hood for the type checking would be preferable.