jax_dataclasses

Use jaxtyping to enrich type annotations

Open lucagrementieri opened this issue 2 years ago • 5 comments

I just discovered the jaxtyping library and I think it could be an interesting alternative to the current typing system proposed by jax_dataclasses.

jaxtyping supports variable-size axes and symbolic expressions in terms of other variable-size axes (see https://github.com/google/jaxtyping/blob/main/API.md), and it has very few requirements.

Do you think that it could be added to jax_dataclasses?

lucagrementieri avatar Nov 17 '22 16:11 lucagrementieri

I've also been following jaxtyping and wouldn't be opposed to deprecating the shape annotation syntax and standardizing on theirs!

Observations on my end:

  • I haven't verified this myself, but for shape/datatype assertions I think jaxtyping should work out of the box with @jdc.pytree_dataclass. (It seems like @jaxtyped has some recent dataclass-related additions that might help?)
  • The main additional feature that the current jdc.EnforcedAnnotationsMixin gives you is a .get_batch_axes() function. It seems like jaxtyping doesn't have any public API for accessing shape annotations, so bringing this functionality over would require manually parsing the annotations.
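To illustrate the second point: as a rough sketch, assuming "batch axes" means the leading dimensions of an array beyond its annotated shape, the lookup could work roughly as follows. The `batch_axes` helper and the plain-tuple shapes are hypothetical stand-ins, not the actual `jdc.EnforcedAnnotationsMixin` implementation.

```python
from typing import Tuple


def batch_axes(actual: Tuple[int, ...], annotated: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return the leading axes of `actual` that precede the annotated shape.

    Hypothetical helper: assumes the annotated shape must match the trailing
    dimensions of the actual shape exactly.
    """
    n = len(annotated)
    # The trailing n dimensions must equal the annotation; n == 0 matches anything.
    if len(actual) < n or (n > 0 and actual[-n:] != annotated):
        raise ValueError(f"shape {actual} does not end with {annotated}")
    # Everything before the annotated dimensions counts as batch axes.
    return actual[: len(actual) - n]
```

For example, `batch_axes((10, 5, 3, 4), (3, 4))` returns `(10, 5)`. With jaxtyping, the annotated shape would first have to be recovered from the annotation, which is the part that might need manual parsing.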

Let me know if you have more thoughts on this, or bandwidth for contributing. :slightly_smiling_face:

brentyi avatar Nov 17 '22 19:11 brentyi

The first point should certainly be verified. For the second point, I did some experiments and I think that typing.get_type_hints is sufficient to access shape annotations. I'll get back to you after more experiments, maybe with a pull request.
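As a hedged sketch of that idea, the snippet below reads annotation metadata back through `typing.get_type_hints`. It uses `typing.Annotated` with a plain shape tuple as a stand-in so that it runs with the standard library alone. The `Pose` class and the `shape_annotations` helper are hypothetical, and jaxtyping's own annotation objects may expose their shape information differently.

```python
import dataclasses
from typing import Annotated, Dict, Tuple, get_args, get_origin, get_type_hints


@dataclasses.dataclass
class Pose:  # hypothetical example class
    # The tuple after the type is the annotated shape (stand-in metadata).
    rotation: Annotated[list, (3, 3)]
    translation: Annotated[list, (3,)]
    name: str = "pose"


def shape_annotations(cls: type) -> Dict[str, Tuple[int, ...]]:
    """Collect the shape metadata attached to each annotated field."""
    shapes = {}
    # include_extras=True keeps the Annotated[...] wrapper; without it the
    # metadata is stripped and only the bare types are returned.
    for field, hint in get_type_hints(cls, include_extras=True).items():
        if get_origin(hint) is Annotated:
            _, *metadata = get_args(hint)
            shapes[field] = metadata[0]
    return shapes
```

Calling `shape_annotations(Pose)` yields a mapping from field names to their annotated shapes and skips fields such as `name` that carry no metadata.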

lucagrementieri avatar Nov 18 '22 16:11 lucagrementieri

Looking forward to it, thanks @lucagrementieri!

brentyi avatar Nov 18 '22 22:11 brentyi

From my experiments I think it's possible to support all the features using jaxtyping with no manual parsing!

I think my pull request should be ready by next week.

lucagrementieri avatar Nov 25 '22 07:11 lucagrementieri

I finally had time to push my PR https://github.com/brentyi/jax_dataclasses/pull/6 ! Sorry for the delay!

lucagrementieri avatar Jan 07 '23 17:01 lucagrementieri