Type annotations for "struct of arrays"?
I love equinox and jaxtyping, and would like to figure out a nice way to annotate stacked or batched `eqx.Module`s, but so far I've failed. For example, I'll often do something like the following:
```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, Float

class SE3(eqx.Module):
    U: Float[Array, "3 3"]
    t: Float[Array, "3"]

# probably wouldn't construct this in this way but it gets the point across
batched_elements = jtu.tree_map(lambda *v: jnp.stack(v), *[SE3(jnp.eye(3), jnp.zeros(3)) for _ in range(10)])
# ridiculous function but also useful to get the point across!
def sum_translations(batch):  # batch: ???  <- what annotation goes here?
    return jax.vmap(lambda F: F.t)(batch).sum()

sum_translations(batched_elements)
```
Is there a reasonable type annotation for `batch` above? I would love to be able to say something like `SE3["n"]`, possibly by modifying the `SE3` class.
Thanks so much, Nick
P.S. I think there's a chance this is related to https://github.com/patrick-kidger/jaxtyping/issues/84, but I have to admit I don't understand Python well enough to figure this out!