Jaxtyping annotations don't work with pyserde

Open Ruhrpottpatriot opened this issue 2 months ago • 5 comments

I'm using Jaxtyping to make sure that my tensors have the correct shape. However, I also need to serialise the tensors for an API, for which I'm using pyserde with a custom global serialiser. If the tensor fields are annotated as plain torch.Tensor, the custom serialiser works and I can serialise the classes to JSON or other pyserde-supported formats. However, when I annotate the tensors with Jaxtyping, e.g. Float32[torch.Tensor, "4"], serialisation fails with the error:

serde.compat.SerdeError: Unsupported type: Tensor

Minimal (not) working example. Uncomment the jax field and its constructor argument to get the error:

from typing import Annotated, Any
from plum import dispatch
import torch

from jaxtyping import Float32
import serde
from serde.json import to_json


class Serializer:
    @dispatch
    def serialize(self, value: torch.Tensor) -> Any:
        return {
            "__tensor__": True,
            "dtype": str(value.dtype),
            "shape": list(value.shape),
            "data": value.cpu().tolist(),
        }


serde.add_serializer(Serializer())


@serde.serde
class Foo:
    tensor_works: torch.Tensor

    annotated_works: Annotated[torch.Tensor, "4"]

    # jax: Float32[torch.Tensor, "4"]


foo = Foo(
    tensor_works=torch.tensor([1000.0, 2000.0, 3000.0]),
    annotated_works=torch.tensor([0.0, 0.0, 0.0, 1.0], dtype=torch.float32),
    # jax=torch.tensor([100.0, 0.0, 0.0, 0.0], dtype=torch.float32), <- Doesn't work
)

j = to_json(foo)
print(j)

Ruhrpottpatriot, Oct 21 '25 12:10

This is a feature that I'd be happy to take a PR on :)

FWIW we currently test pickle and cloudpickle for performing serialisation: https://github.com/patrick-kidger/jaxtyping/blob/main/test/test_serialisation.py . If you can identify some minimal set of tweaks necessary for us to work with pyserde (and add a test), then I'd be happy to make those.

patrick-kidger, Oct 21 '25 13:10

The problem, as far as I can tell, is that for typecheckers torch.Tensor and something like Float32[torch.Tensor, "4"] are two different types, even though they're the same at runtime. What's bugging me is that when I use VSCode to navigate to the definition of Float32 it directs me to the builtin Annotated class, for which serialisation works.

If I manually set the same serialisation logic as a field serialiser, i.e. = field(serializer=tensor_serialiser), then the serialisation works as expected. However, this means that I'd have to repeat this on dozens of fields, hence the global serialiser.

Ruhrpottpatriot, Oct 21 '25 13:10

> The problem, as far as I can tell, is that for typecheckers torch.Tensor and something like Float32[torch.Tensor, "4"] are two different types, even though they're the same at runtime.

Actually, they're different objects at runtime! The jaxtyping annotation is a custom type for which isinstance(...) will perform the shape/dtype checks.
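To make the point above concrete, here is a toy model using only the standard library. It is not jaxtyping's implementation: `ShapeCheckedMeta`, `make_annotation` and `Vec4` are hypothetical names, and a plain `list` stands in for `torch.Tensor`. It only sketches how an annotation can be a separate class object whose `isinstance(...)` check validates shape, and why a lookup keyed on the underlying type would then miss it.

```python
# Toy model only: ShapeCheckedMeta and make_annotation are hypothetical names,
# not part of jaxtyping. A plain list stands in for torch.Tensor.
class ShapeCheckedMeta(type):
    def __instancecheck__(cls, obj):
        # isinstance(obj, annotation) checks the underlying type and the length.
        return isinstance(obj, cls.array_type) and len(obj) == cls.length


def make_annotation(array_type, length):
    # Each call builds a brand-new class object, distinct from array_type.
    return ShapeCheckedMeta(
        f"Checked[{array_type.__name__}, {length}]",
        (),
        {"array_type": array_type, "length": length},
    )


Vec4 = make_annotation(list, 4)

# The annotation is a different object from the underlying type...
print(Vec4 is list)
# ...but isinstance() runs the custom shape check.
print(isinstance([0.0, 0.0, 0.0, 1.0], Vec4))
print(isinstance([1.0, 2.0, 3.0], Vec4))

# A registry keyed on the underlying type does not contain the annotation.
serializers = {list: lambda value: {"data": value}}
print(Vec4 in serializers)
print(Vec4.array_type in serializers)
```

In this toy, a serialiser that matches on the exact declared type would find `list` but not `Vec4`, which is one plausible reading of the behaviour reported in this issue.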

> What's bugging me is that when I use VSCode to navigate to the definition of Float32 it directs me to the builtin Annotated class, for which serialisation works.

As for this, static type checkers don't/can't understand tensor shapes, so jaxtyping is configured to have them ignore these annotations. This is done by ensuring that Float32 behaves the same as Annotated. (In other words, what you're seeing here is a false positive.)

> If I manually set the same serialisation logic as a field serialiser, i.e. = field(serializer=tensor_serialiser) then the serialisation works as expected.

Ah, this suggests to me what a possible fix might be. So jaxtyping annotations are their own type: specifically, they are subclasses of jaxtyping.AbstractArray. Perhaps you need to register that type with pyserde – not torch.Tensor.

patrick-kidger, Oct 21 '25 14:10

> Ah, this suggests to me what a possible fix might be. So jaxtyping annotations are their own type: specifically, they are subclasses of jaxtyping.AbstractArray. Perhaps you need to register that type with pyserde – not torch.Tensor.

At the moment this seems to work in the MWE. I'll test it tomorrow in a larger codebase.

Under the assumption that this really is a solution, at least the docs should be updated to guide pyserde users to this.

Ruhrpottpatriot, Oct 21 '25 16:10

Sadly, in the larger codebase the approach with AbstractArray didn't work. I tested the global serializer registration, a class serializer and the field serializer, and only the latter worked. I assume that when using the field serializer, pyserde ignores the types altogether and just calls the given function.
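The assumption in the comment above can be sketched with a toy dispatcher. This is not pyserde's code: `FieldSpec`, `GLOBAL_SERIALIZERS`, `serialize_field`, `Tensor` and `ShapedTensor` are all hypothetical stand-ins, chosen so that the sketch needs only the standard library. It shows one way a field-level serializer could succeed while a type-keyed global lookup fails for the same field.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class Tensor:
    """Stand-in for torch.Tensor (hypothetical, keeps the sketch stdlib-only)."""

    def __init__(self, data):
        self.data = list(data)


class ShapedTensor:
    """Stand-in for a distinct annotation type the registry has never seen."""


def tensor_to_dict(value: Tensor) -> dict:
    return {"__tensor__": True, "data": value.data}


# Global registry keyed on the exact declared type (an assumption of this toy).
GLOBAL_SERIALIZERS = {Tensor: tensor_to_dict}


@dataclass
class FieldSpec:
    declared_type: type
    serializer: Optional[Callable[[Any], Any]] = None


def serialize_field(spec: FieldSpec, value: Any) -> Any:
    if spec.serializer is not None:
        # Field-level serializer: the declared type is never consulted.
        return spec.serializer(value)
    if spec.declared_type not in GLOBAL_SERIALIZERS:
        raise TypeError(f"no serializer registered for {spec.declared_type.__name__}")
    return GLOBAL_SERIALIZERS[spec.declared_type](value)


value = Tensor([0.0, 0.0, 0.0, 1.0])

# Works: the declared type is in the registry.
print(serialize_field(FieldSpec(Tensor), value))
# Works: the field serializer bypasses the registry lookup.
print(serialize_field(FieldSpec(ShapedTensor, serializer=tensor_to_dict), value))
# Fails: the registry has no entry for the distinct annotation type.
try:
    serialize_field(FieldSpec(ShapedTensor), value)
except TypeError as exc:
    print(exc)
```

Whether pyserde's real lookup behaves like this has not been confirmed in this thread; the sketch only illustrates why the three registration styles could give different results.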

Ruhrpottpatriot, Oct 22 '25 13:10