jax_dataclasses
Add support for jaxtyping
This PR adds support for jaxtyping annotations while preserving all existing features and checks on tensor dimensions.
The PR doesn't update the README, since that could become messy very easily. I'll wait for further guidance before updating it.
Close #5.
Thanks! Looks reasonable overall; my main concern is the private jaxtyping imports. I assume there's no way to get around this?
_MetaAbstractArray is the metaclass of all jaxtyping array annotation types, so the check isinstance(type_hint, _MetaAbstractArray) is the most direct way to identify jaxtyping annotations. There are surely workarounds that avoid it, but they would be more fragile and less elegant.
For _NamedVariadicDim, I think there is no simple workaround, because this class is required to support variadic dimensions such as the batch dimension.
Okay, makes sense! It's definitely not ideal, but having support for jaxtyping here seems useful enough to warrant it. I like how we don't have to worry about import hooks or @jaxtyped for ~~the shape checks~~ checking/getting the batch axes.
(cc @patrick-kidger for any warnings; are there any plans to rework the internals of jaxtyping?)
I can handle the rest of the PR. Some TODOs would be:
- [ ] Consider adding a jaxtyping install dependency with a pinned version.
- [ ] Fixing the tests for Python 3.7. (seems unrelated to this PR)
- [ ] Updating the README.
- [ ] Deprecation warnings for the proprietary shape typing interface.
- [ ] Consider renaming EnforcedAnnotationsMixin while we're at it. This is a mouthful.
There aren't any current plans. But taking a quick glance at your code, I think this will fail for a type hint of the form tuple[Float[Array, "foo"], ...], i.e. one in which the array is nested within another type hint? (I didn't check that closely though, so I might be wrong.)
Anyway, jaxtyping hints are expected to be validated using a runtime type checker, such as typeguard or beartype. I'd recommend that you simply do the same thing, as they'll handle the details for you: both the nesting above, and avoiding the need to access private jaxtyping functionality.
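To make the nesting concern concrete, here is a minimal sketch using only the standard library. `ArrayHint` and `FloatVec` are hypothetical toy classes standing in for array annotations (they are not jaxtyping types), and `iter_array_hints` is an illustrative helper, not part of either library. It shows how a top-level check misses a hint nested inside a tuple, while a recursive walk over `typing.get_args` finds it.

```python
from typing import Any, Iterator, Tuple, get_args


class ArrayHint:
    """Toy stand-in for an array annotation base class (not a jaxtyping type)."""


class FloatVec(ArrayHint):
    """Toy stand-in for something like Float[Array, "foo"]."""


def is_array_hint(hint: Any) -> bool:
    # issubclass may raise TypeError when `hint` is not a class.
    try:
        return issubclass(hint, ArrayHint)
    except TypeError:
        return False


def iter_array_hints(hint: Any) -> Iterator[type]:
    """Yield every array hint found in `hint`, including nested ones."""
    if is_array_hint(hint):
        yield hint
        return
    # get_args returns () for anything that is not a subscripted generic,
    # so the recursion stops there.
    for arg in get_args(hint):
        yield from iter_array_hints(arg)


nested = Tuple[FloatVec, ...]
flat_check = is_array_hint(nested)      # top-level check only
found = list(iter_array_hints(nested))  # recursive walk
```

A runtime type checker such as typeguard or beartype would handle this traversal itself; the sketch only illustrates why a flat check is not enough.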
Side note: if you're working on a project like this then you may find Equinox interesting. In particular equinox.Module is also a dataclass-pytree combo, with most (all?) of the expected bells-and-whistles: serialisation, immutability etc. (I suppose I've not really tested static type checking, as I'm mostly a non-user of that.)
I like the neat syntax of your copy_and_mutate, by the way. (Equinox's equivalent is equinox.tree_at, which is safer but a bit harder to use.)
Thanks!
I've also been following Equinox; definitely the "how to build pytrees" + tooling compatibility landscapes have improved quite a bit since I started jax_dataclasses in ~late 2020. For now I think mypy compatibility + the Static[] API + copy_and_mutate are still nice enough for me to keep the library around, but sooner or later I should revisit whether the library still makes sense given developments in equinox, flax, etc (especially if a copy_and_mutate-style API is merged into flax: https://github.com/google/flax/pull/2735).
I also agree that typeguard or beartype makes sense for asserting that the shapes are correct, but for this PR the main purpose is to support jaxtyping annotations for the dataclass.get_batch_axes() helper that currently works for jax_dataclasses-proprietary shape annotations.
For this we need to figure out which axes in the array shapes correspond to the variadic dimension, which leaves the options of: (a) touching the private bits of jaxtyping, (b) trying to convince @patrick-kidger to expose a public API for reasoning about jaxtyping types*, or (c) not implementing this functionality.
*maybe something like: (jaxtyping type, array) -> labels for each axis in the array. Any chance you're open to something like this? (understand if not)
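As a rough illustration of the "(type, array) -> labels for each axis" idea, the sketch below works on a plain dimension string and a shape tuple. Everything here is an assumption for illustration: `label_axes` and `batch_axes` are hypothetical names, the parser only understands space-separated names with at most one `*`-prefixed variadic name, and it does not touch jaxtyping at all.

```python
from typing import List, Tuple


def label_axes(dim_str: str, shape: Tuple[int, ...]) -> List[str]:
    """Assign a label to every axis of `shape` from a spec such as "*batch h w".

    Hypothetical helper: at most one variadic name (prefixed with "*") is
    supported, and it absorbs all axes not claimed by the other names.
    """
    names = dim_str.split()
    variadic = [n for n in names if n.startswith("*")]
    if len(variadic) > 1:
        raise ValueError("at most one variadic dimension is supported")
    num_fixed = len(names) - len(variadic)
    num_variadic = len(shape) - num_fixed
    if num_variadic < 0 or (not variadic and num_variadic != 0):
        raise ValueError(f"shape {shape} does not match {dim_str!r}")

    labels: List[str] = []
    for name in names:
        if name.startswith("*"):
            # The variadic name is repeated once per axis it absorbs.
            labels.extend([name[1:]] * num_variadic)
        else:
            labels.append(name)
    return labels


def batch_axes(
    dim_str: str, shape: Tuple[int, ...], batch_name: str = "batch"
) -> Tuple[int, ...]:
    """Return the sizes of the axes labeled `batch_name`."""
    labels = label_axes(dim_str, shape)
    return tuple(size for size, label in zip(shape, labels) if label == batch_name)
```

For example, `label_axes("*batch h w", (2, 3, 4, 5))` labels the first two axes as `batch`, so `batch_axes` on the same inputs returns `(2, 3)`.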
You should be able to replace isinstance(type_hint, _MetaAbstractArray) with issubclass(type_hint, AbstractArray). (Which is public API.)
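The relationship between the two checks can be sketched with toy classes. `_MetaToyArray`, `ToyAbstractArray`, and `ToyFloat` below are hypothetical stand-ins, not jaxtyping's classes; the point is only that an `isinstance` check against a private metaclass and an `issubclass` check against a public base agree, with the caveat that `issubclass` needs a guard for hints that are not classes.

```python
from typing import Any, Tuple


class _MetaToyArray(type):
    """Private metaclass; stands in for a library-internal detail."""


class ToyAbstractArray(metaclass=_MetaToyArray):
    """Public base class; stands in for a public AbstractArray-style base."""


class ToyFloat(ToyAbstractArray):
    """A concrete toy annotation type."""


def is_array_annotation(type_hint: Any) -> bool:
    """Check against the public base, tolerating non-class hints."""
    try:
        return issubclass(type_hint, ToyAbstractArray)
    except TypeError:
        # e.g. a string annotation, which is not a class at all.
        return False


via_metaclass = isinstance(ToyFloat, _MetaToyArray)  # relies on a private name
via_base = is_array_annotation(ToyFloat)             # relies only on the public base
```

The `try`/`except` is a deliberate choice over an `isinstance(type_hint, type)` guard, which is not reliable for subscripted builtin generics on some Python versions.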
At that point I can see that you'd want to modify its dimensions. I think the best way to do this would be to submit a PR against jaxtyping that records cls and item here:
https://github.com/google/jaxtyping/blob/59e8fb0d18325f990a9d59ee35e90c04b699cab8/jaxtyping/array_types.py#L400
so that you can then look these up, modify them as desired, and then recreate the type hint through the public jaxtyping API (e.g. cls[item] to recreate the same hint).
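A toy version of this "record cls and item, then re-subscript" pattern might look like the sketch below. It is not jaxtyping code: `ToyFloat`, the attribute names `recorded_cls` and `recorded_item`, and the helper `prepend_batch` are all hypothetical, chosen only to mirror the suggestion above.

```python
from typing import Any, Tuple


class ToyFloat:
    """Toy annotation class whose subscripted forms remember how they were made."""

    def __class_getitem__(cls, item: Tuple[Any, str]) -> type:
        _, dim_str = item
        # Record the class and the subscript on the generated type so that
        # callers can read them back and re-subscript later.
        return type(
            f"{cls.__name__}[{dim_str!r}]",
            (cls,),
            {"recorded_cls": cls, "recorded_item": item},
        )


def prepend_batch(hint: type, name: str = "batch") -> type:
    """Rebuild `hint` with a leading variadic dimension, using subscripting only."""
    array_type, dim_str = hint.recorded_item
    return hint.recorded_cls[array_type, f"*{name} {dim_str}"]


hint = ToyFloat[list, "h w"]
same = hint.recorded_cls[hint.recorded_item]  # recreate an equivalent hint
batched = prepend_batch(hint)                 # modified copy with "*batch" prepended
```

With something like this in place, downstream code would only need the two recorded attributes plus ordinary subscripting, rather than any private helper classes.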