Type hinting for `eqx.Module` created under `vmap`

Open · michael-0brien opened this issue 8 months ago · 2 comments

Hello, I have something like the following pattern:

```python
import equinox as eqx
from jaxtyping import Array, Float

#
# Library code
#
class TestModule(eqx.Module):

    a: Float[Array, ""]
    b: Float[Array, "2"]

    def __init__(self, a, b):
        # Minimal body for illustration: every field must be assigned,
        # otherwise Equinox raises an error at construction time.
        self.a = a
        self.b = b

def compute_something(vmapped_module: TestModule):
    """Specifically takes in TestModule with a batch dimension. How to type hint?"""
    ...

#
# Runtime code
#
import jax.random as jr

@eqx.filter_vmap
def make_module(a, b):
    # Inside the vmap, `a` has shape () and `b` has shape (2,).
    return TestModule(a, b)

dim = 10
key_a, key_b = jr.split(jr.key(1234))  # independent keys for `a` and `b`
a = jr.normal(key_a, shape=(dim,))
b = jr.normal(key_b, shape=(dim, 2))

vmapped_module = make_module(a, b)  # fields now have shapes (10,) and (10, 2)
compute_something(vmapped_module)
```

How would I type hint the `compute_something` function, which specifically takes in a `TestModule` with particular batch dimensions?

michael-0brien, May 06 '25 15:05

So this isn't something we have a nice way to annotate in jaxtyping right now. I think supporting this is the same as https://github.com/patrick-kidger/jaxtyping/issues/242. I'm not sure how fiddly that would be, but if it doesn't require too much magic then it's something I'd be happy to take a PR on.

patrick-kidger, May 06 '25 17:05

Aha, the solution in #242 is something like what I had in mind. My only comment is that my modules can have non-arrays as leaves, so using `filter_vmap` under the hood for the type checking would be preferable.

michael-0brien, May 06 '25 18:05