jax_dataclasses
Use jaxtyping to enrich type annotations
I just discovered the jaxtyping library, and I think it could be an interesting alternative to the current annotation system provided by jax_dataclasses.
jaxtyping supports variable-size axes and symbolic expressions defined in terms of other variable-size axes (see https://github.com/google/jaxtyping/blob/main/API.md), and it has very few dependencies.
Do you think that it could be added to jax_dataclasses?
I've also been following jaxtyping and wouldn't be opposed to deprecating the shape annotation syntax and standardizing on theirs!
Observations on my end:
- I haven't verified this myself, but for shape/dtype assertions I think jaxtyping should work out of the box with @jdc.pytree_dataclass. (It seems like @jaxtyped has some recent dataclass-related additions that might help?)
- The main additional feature that the current jdc.EnforcedAnnotationsMixin gives you is a .get_batch_axes() method. jaxtyping doesn't seem to have any public API for accessing shape annotations, so bringing this functionality over would require manually parsing the annotations.
Let me know if you have more thoughts on this, or bandwidth for contributing. :slightly_smiling_face:
The first point should definitely be verified.
For the second point, I ran some experiments, and I think typing.get_type_hints is sufficient to access the shape annotations.
I'll get back to you after more experiments, and maybe with a pull request.
Looking forward to it, thanks @lucagrementieri!
From my experiments, I think it's possible to support all the features using jaxtyping with no manual parsing!
I think my pull request should be ready by next week.
I finally had time to push my PR: https://github.com/brentyi/jax_dataclasses/pull/6. Sorry for the delay!