Equinox very nearly supports serialization/deserialization with pydantic.

Open • benmacadam64 opened this issue 7 months ago • 1 comment

Hi there,

This is more of an observation/proposal. I would like to jump into the equinox codebase and add some of these features myself, if this is something you'd be interested in.

I've found checkpointing, and more generally serialization/deserialization of models with Equinox, to be a bit of a pain point. I spent some time playing around yesterday, and it seems like Equinox very nearly supports serialization to/from pydantic, provided pydantic is taught to serialize JAX arrays as their shape and dtype and to deserialize them as ShapeDtypeStruct (I only really looked at shape/dtype; I didn't dig into sharding etc.).

```python
import jax
import jax.numpy as jnp
from jax import ShapeDtypeStruct as S
from pydantic_core import core_schema

def _array_core_schema(cls, source_type, handler):
    def _validate(v):
        # Also accept placeholders produced by a previous round trip.
        if isinstance(v, (jax.Array, S)):
            return v
        if isinstance(v, dict) and {'shape', 'dtype'} <= v.keys():
            # Deserialize to a shape/dtype placeholder; no array values are stored.
            return S(tuple(v['shape']), jnp.dtype(v['dtype']))
        # pydantic v2 only turns ValueError/AssertionError into a ValidationError.
        raise ValueError('Expected jax.Array or {"shape": ..., "dtype": ...}')

    def _serialize(v):
        return {'shape': list(v.shape), 'dtype': str(v.dtype)}

    return core_schema.no_info_plain_validator_function(
        _validate,
        serialization=core_schema.plain_serializer_function_ser_schema(_serialize),
    )

# attach to the class (works for JAX ≥0.4.14 where Array is a real type)
jax.Array.__get_pydantic_core_schema__ = classmethod(_array_core_schema)
```

That let me serialize a few basic models (built from layers like nn.Sequential, nn.Linear, etc.) after wrapping the model in a BaseModel (I'm sure a TypeAdapter would work just as well). I think a few layers would require minor changes (like some of the pooling layers), and users would need to be mindful of how they use nn.Lambda if they want their model to serialize easily.

I would propose:

  1. Making this opt-in behaviour (e.g. having equinox[pydantic] as an extra option).
  2. Adding a better version of the _array_core_schema function that supports sharding, if that sharding data can easily be serialized.
  3. Making minor changes to a few built-in layers to ensure they are pydantic-serializable, and perhaps adding a callable_dataclass_factory function.
  4. Maybe some utilities to support checkpointing with tools like Orbax or Weights & Biases out of the box?

I don't think it's necessary to modify eqx.Module to support pydantic, since TypeAdapter seems to do the job.

I think the most problematic part of this is probably modifying the behaviour of jax.Array to support pydantic from outside the JAX library, which is why I think this should be opt-in behaviour (and I would understand if you think it's a complete no-go).
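One possible opt-in variant that avoids patching jax.Array globally is pydantic's documented pattern for third-party types: attach the same validate/serialize pair through typing.Annotated, so only fields that use the alias are affected. Below is a minimal sketch, assuming pydantic v2. PydanticArray, _ArraySpec and TinyLinear are hypothetical names, and TinyLinear is a plain dataclass standing in for a user-defined module. This only reaches fields you annotate yourself, since built-in eqx.nn layers declare their own field types. Like the hook above, it stores shape/dtype only, so the round trip yields ShapeDtypeStruct placeholders rather than weights.

```python
import dataclasses
from typing import Annotated

import jax
import jax.numpy as jnp
from pydantic import TypeAdapter
from pydantic_core import core_schema


def _validate(v):
    # Arrays (and placeholders from an earlier round trip) pass through unchanged.
    if isinstance(v, (jax.Array, jax.ShapeDtypeStruct)):
        return v
    if isinstance(v, dict) and {"shape", "dtype"} <= v.keys():
        return jax.ShapeDtypeStruct(tuple(v["shape"]), jnp.dtype(v["dtype"]))
    raise ValueError('Expected jax.Array or {"shape": ..., "dtype": ...}')


def _serialize(v):
    # Only structure is recorded; array values are dropped.
    return {"shape": list(v.shape), "dtype": str(v.dtype)}


class _ArraySpec:
    """Hypothetical Annotated marker carrying the schema; jax.Array itself is untouched."""

    @classmethod
    def __get_pydantic_core_schema__(cls, source_type, handler):
        return core_schema.no_info_plain_validator_function(
            _validate,
            serialization=core_schema.plain_serializer_function_ser_schema(_serialize),
        )


# Hypothetical opt-in alias: only fields annotated with it get the schema.
PydanticArray = Annotated[jax.Array, _ArraySpec]


@dataclasses.dataclass
class TinyLinear:  # hypothetical stand-in for a user-defined module
    weight: PydanticArray
    bias: PydanticArray


adapter = TypeAdapter(TinyLinear)
layer = TinyLinear(weight=jnp.zeros((3, 2)), bias=jnp.zeros(3))
dumped = adapter.dump_python(layer)         # nested dicts of shape/dtype
skeleton = adapter.validate_python(dumped)  # TinyLinear holding ShapeDtypeStructs
```

The design choice is that the opt-in lives in the type annotation rather than in a global side effect, which sidesteps the concern about changing jax.Array from outside JAX. The cost is that it cannot see inside classes whose annotations you don't control.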

benmacadam64 • Jun 13 '25 18:06

Hi there! I'm not a pydantic user, so I'm not super familiar with the details of their API. But if I understand you correctly then I think:

  • As per your last comment, indeed we probably can't do anything here like _array_core_schema, since that's really a JAX-level thing, not an Equinox-level thing.
  • But if there are backward-compatible tweaks to eqx.nn.* that would make it easier for you to build pydantic-compatible serialization then I'd be very happy to take those as a PR! :) (I wouldn't add a callable_dataclass_factory; that's basically what subclassing eqx.Module already is today.)

FWIW, for serialization/checkpointing I usually recommend this example. I think something of this form should be compatible with orbax/wandb/etc., although, in the same spirit as above, if the experience can be streamlined then I'd be happy to hear suggestions!

patrick-kidger • Jun 14 '25 21:06