Type annotations for "struct of arrays"?

Open nboyd opened this issue 3 months ago • 1 comments

I love equinox and jaxtyping and would like to figure out a nice way to annotate stacked or batched eqx.Modules but so far I've failed. For example, I'll often do something like the following:

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, Float

class SE3(eqx.Module):
    U: Float[Array, "3 3"]
    t: Float[Array, "3"]

# probably wouldn't construct this in this way but it gets the point across
batched_elements = jtu.tree_map(lambda *v: jnp.stack(v), *[SE3(jnp.eye(3), jnp.zeros(3)) for _ in range(10)])

# ridiculous function but also useful to get the point across!
def sum_translations(batch: ???):
    return jax.vmap(lambda F: F.t)(batch).sum()

sum_translations(batched_elements)

Is there a reasonable type annotation for batch above? I would love to be able to say something like SE3["n"], possibly by modifying the SE3 class.

Thanks so much, Nick

P.S. I think there's a chance this is related to https://github.com/patrick-kidger/jaxtyping/issues/84, but I have to admit I don't understand Python well enough to figure this out!

nboyd, Mar 25 '24 20:03