Jaxtyping annotations don't work with pyserde
I'm using jaxtyping to make sure that my tensors have the correct shape. However, I also need to serialise the tensors for an API, for which I'm using pyserde with a custom global serialiser.
If the tensor fields are annotated as plain torch.Tensor, the custom serialiser works and I can serialise the classes to JSON or any other format pyserde supports. However, when I annotate the tensors with jaxtyping, e.g. Float32[torch.Tensor, "4"], the serialiser fails with the error:
serde.compat.SerdeError: Unsupported type: Tensor
Minimal (not) working example; uncomment the jax field (and its value) to get the error:
from typing import Annotated, Any

from plum import dispatch
import torch
from jaxtyping import Float32
import serde
from serde.json import to_json


class Serializer:
    @dispatch
    def serialize(self, value: torch.Tensor) -> Any:
        return {
            "__tensor__": True,
            "dtype": str(value.dtype),
            "shape": list(value.shape),
            "data": value.cpu().tolist(),
        }


serde.add_serializer(Serializer())


@serde.serde
class Foo:
    tensor_works: torch.Tensor
    annotated_works: Annotated[torch.Tensor, "4"]
    # jax: Float32[torch.Tensor, "4"]


foo = Foo(
    tensor_works=torch.tensor([1000.0, 2000.0, 3000.0]),
    annotated_works=torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float32),
    # jax=torch.tensor([100.0, 0.0, 0.0, 0.0], dtype=torch.float32),  <- Doesn't work
)

j = to_json(foo)
print(j)
This is a feature that I'd be happy to take a PR on :)
FWIW, we currently test serialisation with pickle and cloudpickle: https://github.com/patrick-kidger/jaxtyping/blob/main/test/test_serialisation.py . If you can identify a minimal set of tweaks needed for jaxtyping to work with pyserde (plus a test), then I'd be happy to make those.
The problem, as far as I can tell, is that for type checkers torch.Tensor and something like Float32[torch.Tensor, "4"] are two different types, even though they're the same at runtime. What's bugging me is that when I use VSCode to navigate to the definition of Float32, it takes me to typing.Annotated, for which serialisation works.
If I manually set the same serialisation logic as a field serialiser, i.e. = field(serializer=tensor_serialiser), then the serialisation works as expected. However, this means I'd have to add that to several dozen fields, hence the global serialiser.
The problem, as far as I can tell, is that for typecheckers torch.Tensor and something like Float32[torch.Tensor, "4"] are two different types, even though they're the same at runtime.
Actually, they're different objects at runtime! The jaxtyping annotation is a custom type for which isinstance(...) will perform the shape/dtype checks.
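A toy, stdlib-only model of what this means: each parametrised annotation is a freshly built class whose metaclass overrides isinstance, while the runtime value stays an ordinary object of the underlying type. The names _ShapeMeta, AbstractToyArray and make_annotation are hypothetical, lists stand in for tensors and list length stands in for shape; this is an illustration, not jaxtyping's actual implementation.

```python
class _ShapeMeta(type):
    """Metaclass whose isinstance() check validates type and length (a toy 'shape')."""

    def __instancecheck__(cls, obj: object) -> bool:
        return isinstance(obj, cls.array_type) and len(obj) == cls.length


class AbstractToyArray(metaclass=_ShapeMeta):
    array_type: type = list
    length: int = -1  # overridden by each parametrised subclass


def make_annotation(array_type: type, length: int) -> type:
    """Build a new class per annotation, loosely mimicking Float32[torch.Tensor, "4"]."""
    name = f"Toy[{array_type.__name__}, {length}]"
    return _ShapeMeta(name, (AbstractToyArray,), {"array_type": array_type, "length": length})


Vec4 = make_annotation(list, 4)
value = [100.0, 0.0, 0.0, 0.0]

print(Vec4 is list)                        # False: the annotation is its own class
print(type(value) is list)                 # True: the runtime value is still a plain list
print(isinstance(value, Vec4))             # True: length matches
print(isinstance([1.0, 2.0], Vec4))        # False: wrong length
print(issubclass(Vec4, AbstractToyArray))  # True
```

So a library that looks at the declared annotation sees Vec4 (or, in the real case, a jaxtyping class), not list (or torch.Tensor), even though the values are identical.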
What's bugging me is that when I use VSCode to navigate to the definition of Float32 it directs me to the builtin Annotated class, for which serialisation works.
As for this, static type checkers don't/can't understand tensor shapes, so jaxtyping is configured to have them ignore these annotations. This is done by ensuring that Float32 behaves the same as Annotated. (In other words, what you're seeing here is a false positive.)
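The difference is visible with the standard typing introspection helpers. A plain Annotated alias keeps its base type recoverable through get_origin/get_args, which is presumably how annotated_works ends up being treated as a torch.Tensor (an assumption inferred from the MWE, not checked against pyserde's source). At runtime a jaxtyping annotation is a class instead, so there is no Annotated wrapper to strip. In this sketch, base_type and FakeJaxtypingAnnotation are hypothetical, and list stands in for torch.Tensor.

```python
from typing import Annotated, Any, get_args, get_origin


def base_type(tp: Any) -> Any:
    """Strip Annotated[...] down to its base type; return anything else unchanged."""
    if get_origin(tp) is Annotated:
        return get_args(tp)[0]
    return tp


class FakeJaxtypingAnnotation:  # hypothetical stand-in for Float32[torch.Tensor, "4"] at runtime
    pass


plain = Annotated[list, "4"]  # list stands in for torch.Tensor
print(get_args(plain))                      # (<class 'list'>, '4')
print(base_type(plain))                     # <class 'list'>
print(get_origin(FakeJaxtypingAnnotation))  # None: a plain class, nothing to unwrap
print(base_type(FakeJaxtypingAnnotation))   # the class itself, not list
```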
If I manually set the same serialisation logic as a field serialiser, i.e. = field(serializer=tensor_serialiser) then the serialisation works as expected.
Ah, this suggests to me what a possible fix might be. So jaxtyping annotations are their own type: specifically, they are subclasses of jaxtyping.AbstractArray. Perhaps you need to register that type with pyserde – not torch.Tensor.
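If pyserde resolves custom serialisers from the declared field type (an assumption; this toy does not reproduce pyserde's internals), then a handler registered for torch.Tensor cannot match an annotation class that, as in this toy, derives from AbstractArray but not from torch.Tensor, whereas a handler for the shared base class covers every shape/dtype variant. A minimal sketch with a hypothetical ToyRegistry that walks the annotation's MRO; Tensor, AbstractArray and Float32Tensor4 are stand-in classes.

```python
from typing import Any, Callable, Dict, Optional

Serialiser = Callable[[Any], Any]


class ToyRegistry:
    """Hypothetical registry keyed on declared field types (not pyserde's internals)."""

    def __init__(self) -> None:
        self._handlers: Dict[type, Serialiser] = {}

    def register(self, tp: type, fn: Serialiser) -> None:
        self._handlers[tp] = fn

    def find(self, annotation: type) -> Optional[Serialiser]:
        # Walk the annotation's MRO so a handler for a base class also covers subclasses.
        for klass in annotation.__mro__:
            if klass in self._handlers:
                return self._handlers[klass]
        return None


class Tensor:  # stand-in for torch.Tensor
    pass


class AbstractArray:  # stand-in for jaxtyping.AbstractArray
    pass


class Float32Tensor4(AbstractArray):  # stand-in for Float32[torch.Tensor, "4"]
    pass


registry = ToyRegistry()
registry.register(Tensor, lambda v: "tensor handler")
print(registry.find(Tensor) is not None)  # True
print(registry.find(Float32Tensor4))      # None: the annotation is not a Tensor subclass

registry.register(AbstractArray, lambda v: "array handler")
print(registry.find(Float32Tensor4)(None))  # array handler
```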
Ah, this suggests to me what a possible fix might be. So jaxtyping annotations are their own type: specifically, they are subclasses of jaxtyping.AbstractArray. Perhaps you need to register that type with pyserde – not torch.Tensor.
At the moment this seems to work in the MWE. I'll test it tomorrow in a larger codebase.
If this really is the solution, the docs should at least be updated to point pyserde users to it.
Sadly, the AbstractArray approach didn't work in the larger codebase.
I tested global serializer registration, a class serializer and a field serializer, and only the last one worked. I assume that with a field serializer pyserde ignores the types altogether and just calls the given functions.
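Since the per-field route is the one that works, the repetition across dozens of fields can at least be reduced by binding the serialiser once in a small helper. The sketch below is a stdlib dataclasses analogue, not pyserde: field_with_serialiser, to_dict, list_serialiser, tensor_field and Float32Tensor4 are hypothetical, and it only illustrates the point that a field-level serialiser is called directly without consulting the annotation. With pyserde the corresponding shortcut would presumably be functools.partial(field, serializer=tensor_serialiser), which has not been tested here.

```python
import dataclasses
import functools
from typing import Any, Callable, Dict


def field_with_serialiser(serialiser: Callable[[Any], Any], **kwargs: Any) -> Any:
    """Hypothetical helper: attach a serialiser to a dataclass field via its metadata."""
    return dataclasses.field(metadata={"serialiser": serialiser}, **kwargs)


def to_dict(obj: Any) -> Dict[str, Any]:
    """Serialise a dataclass instance, preferring per-field serialisers."""
    out: Dict[str, Any] = {}
    for f in dataclasses.fields(obj):
        value = getattr(obj, f.name)
        custom = f.metadata.get("serialiser")
        # A field-level serialiser is called directly; f.type is never consulted.
        out[f.name] = custom(value) if custom is not None else value
    return out


def list_serialiser(value: list) -> Dict[str, Any]:
    """Hypothetical stand-in for the tensor serialiser from the MWE (lists instead of tensors)."""
    return {"__tensor__": True, "shape": [len(value)], "data": list(value)}


# Bind the serialiser once and reuse the short helper on every tensor field.
tensor_field = functools.partial(field_with_serialiser, list_serialiser)


class Float32Tensor4:  # stand-in annotation that no type-keyed registry knows about
    pass


@dataclasses.dataclass
class Foo:
    jax: Float32Tensor4 = tensor_field()
    other: Float32Tensor4 = tensor_field()
    label: str = "example"  # no per-field serialiser: passed through unchanged


print(to_dict(Foo(jax=[100.0, 0.0, 0.0, 0.0], other=[1.0, 2.0])))
```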