[export] Add support for serialization for some custom PyTree nodes
See the added documentation for jax._src.export.register_pytree_node_serialization and jax._src.export.register_namedtuple_serialization.
Serialization of PyTree nodes is needed to serialize the in_tree and out_tree of an Exported function (not to serialize actual instances of the custom types).
When writing this I looked at how TensorFlow handles namedtuple. It does so transparently, without requiring the user to register a serialization handler for the namedtuple type. The disadvantage is that on deserialization a fresh, distinct namedtuple type is created for each input and output type of the serialized function. This means that calling the deserialized function returns outputs of different types than the function that was serialized, which can be confusing.
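To illustrate why a freshly created namedtuple type is confusing, here is a minimal sketch in plain Python. It only mimics the effect of rebuilding a type from its name and fields; it is not TensorFlow's code, and the names Original, Rebuilt and Pair are hypothetical.

```python
import collections

# The type the user defined before serialization.
Original = collections.namedtuple("Pair", ["a", "b"])

# Simulate deserialization rebuilding a type from the saved name and fields.
Rebuilt = collections.namedtuple("Pair", ["a", "b"])

value = Rebuilt(1, 2)

# Same name and fields, but two unrelated classes.
distinct_types = Original is not Rebuilt
# Type checks against the user's own type fail.
is_original = isinstance(value, Original)
# Plain tuple equality still holds, which can hide the mismatch.
equal_as_tuples = value == Original(1, 2)
```

Code that dispatches on the type of the outputs (for example with isinstance) would behave differently before and after the round trip, even though the values compare equal.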
The Python pickle module does a bit better: it attempts to look up the namedtuple type as a module attribute in the deserializing code, automatically importing the module whose name was saved during serialization. This is too much magic for my taste, and can result in strange import errors.
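The pickle behavior described above can be reproduced with the standard library alone. In this sketch the module name fake_geometry is hypothetical and is created in memory to stand in for user code that defines the namedtuple.

```python
import collections
import pickle
import sys
import types

# Hypothetical module standing in for the user code that defines the type.
mod = types.ModuleType("fake_geometry")
sys.modules["fake_geometry"] = mod
mod.Point = collections.namedtuple("Point", ["x", "y"], module="fake_geometry")

payload = pickle.dumps(mod.Point(1, 2))

# The payload records the module and type names, not the type definition.
names_in_payload = b"fake_geometry" in payload and b"Point" in payload

# Where the module is available, the original type is found by attribute lookup.
restored = pickle.loads(payload)
same_type = type(restored) is mod.Point

# Where the module cannot be imported, loading fails with an import error.
del sys.modules["fake_geometry"]
try:
    pickle.loads(payload)
    load_failed = False
except ImportError:
    load_failed = True
```

The last step shows the kind of import error that can surface far from the code that performed the serialization.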
Hence I added an explicit step for the user to say how they want the namedtuple to be serialized and deserialized.
Since I wanted to also add support for collections.OrderedDict, which users are asking for, I added more general support for PyTree custom nodes. Note that this registration mechanism works in conjunction with the PyTree custom node registration mechanism. The burden is on the user to decide how to serialize and deserialize the custom auxdata that the PyTree custom registration mechanism uses. Not all custom types will be serializable, but many commonly used ones, e.g., dataclasses, can now be inputs and outputs of the serialized functions.
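As a rough illustration of what deciding how to serialize the auxdata can involve, here is a sketch of a matching pair of functions. It assumes the auxdata is a tuple of string keys (as it might be for an ordered mapping) and that the serialized form is bytes; both are assumptions, and the function names are hypothetical. The documentation of register_pytree_node_serialization is the authority on the exact parameters it expects.

```python
import json


def serialize_auxdata(aux: tuple) -> bytes:
    """Encode a tuple of string keys as UTF-8 JSON bytes (assumed format)."""
    return json.dumps(list(aux)).encode("utf-8")


def deserialize_auxdata(data: bytes) -> tuple:
    """Inverse of serialize_auxdata: rebuild the tuple of keys, in order."""
    return tuple(json.loads(data.decode("utf-8")))


# Hypothetical auxdata: the ordered keys of a mapping-like custom node.
keys = ("weights", "bias")
round_tripped = deserialize_auxdata(serialize_auxdata(keys))
```

The important property is that the two functions are exact inverses, so the tree structure reconstructed on deserialization matches the one that was serialized, including key order.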