Serialization of static fields?
Thanks for the handy library!

I have a `pytree_dataclass` that contains a few `static_field`s that I would like to have serialized by the facilities in `flax.serialization`. I noticed that `jax_dataclasses.asdict` handles these, but that `flax.serialization.to_state_dict` and `flax.serialization.to_bytes` both ignore them. What is the correct way (if any) to have these fields included in flax's serialization? Should I be using another technique?
```python
import jax_dataclasses as jdc
from jax import numpy as jnp
import flax.serialization as fs

@jdc.pytree_dataclass
class Demo:
    a: jnp.ndarray = jnp.ones(3)
    b: bool = jdc.static_field(default=False)

demo = Demo()
print(f'{jdc.asdict(demo) = }')
print(f'{fs.to_state_dict(demo) = }')
print(f'{fs.from_bytes(Demo, fs.to_bytes(demo)) = }')

# jdc.asdict(demo) = {'a': array([1., 1., 1.]), 'b': False}
# fs.to_state_dict(demo) = {'a': DeviceArray([1., 1., 1.], dtype=float64)}
# fs.from_bytes(Demo, fs.to_bytes(demo)) = {'a': array([1., 1., 1.])}
```
Thanks in advance!
Appreciate the detailed example :)
To me the main reasons for excluding the static fields are:
- State dictionaries in Flax are expected to be mappings from strings to arrays (source), but the purpose of static fields is primarily to include values that aren't arrays.
- The deserialization utilities (e.g. `flax.serialization.from_state_dict(target, state)`) in Flax map an unpopulated pytree to a populated one with the exact same tree structure. The static fields are treated as part of the tree structure, so they should already be present in the input pytree (e.g. the `target` argument), as in the sketch below.
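To make that second point concrete, here's a minimal sketch (reusing the `Demo` class from your snippet) of how the static field has to already be populated on the deserialization target:

```python
import flax.serialization as fs
from jax import numpy as jnp

# The target pytree carries the tree structure, including the static
# field `b`; the state dict only carries the array leaves.
template = Demo(a=jnp.zeros(3), b=True)
state = fs.to_state_dict(Demo())          # contains only {'a': ...}
restored = fs.from_state_dict(template, state)
# restored.a comes from `state`; restored.b stays True, from `template`.
```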
Does that make sense?
For serializing/deserializing dataclasses, the usual pattern I use is to save two files: the serialized state dictionary from Flax (for ML applications: a "checkpoint") and a configuration object that can be passed into a helper function for instantiating the dataclass with all of the right array shapes + static fields populated (for ML applications: a "model config"). Here's an example of a helper function like this.
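In code, a rough sketch of that two-file pattern might look like the following; `DemoConfig` and `make_demo` are made-up names standing in for the "model config" and instantiation helper, again reusing the `Demo` class from above:

```python
import dataclasses
import pickle

import flax.serialization as fs
from jax import numpy as jnp

# Hypothetical "model config": everything needed to rebuild the tree
# structure, including static fields and array shapes.
@dataclasses.dataclass
class DemoConfig:
    size: int
    b: bool

def make_demo(config: DemoConfig) -> Demo:
    # Placeholder arrays with the right shapes; static fields set from
    # the config. The checkpoint will overwrite the array leaves.
    return Demo(a=jnp.zeros(config.size), b=config.b)

# Save: a checkpoint plus a config, as two files.
config = DemoConfig(size=3, b=False)
demo = make_demo(config)
with open("demo.ckpt", "wb") as f:
    f.write(fs.to_bytes(demo))
with open("demo.config", "wb") as f:
    pickle.dump(config, f)

# Load: rebuild the structure from the config, then fill in the arrays.
with open("demo.config", "rb") as f:
    config = pickle.load(f)
with open("demo.ckpt", "rb") as f:
    demo = fs.from_bytes(make_demo(config), f.read())
```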
I think this works okay, but open to suggestions for improvements/new APIs. Maybe it's possible for a serialization helper to return a 2-element tuple consisting of both the state dictionary and the attributes needed to reproduce the tree structure + static fields?
Yes, that makes sense. Thanks. As a long-time user of PyTorch, I have always enjoyed the simplicity of `torch.load` and `torch.save` (pickles with some extra sauce, I believe) and have also found it convenient that by default `load` automatically places tensors back on the GPU from which they were saved.
For my daily work in a Jupyter notebook in which I rewrite classes a lot, doing a straight `torch.save`/`torch.load` of dataclasses is fragile because the definitions of my dataclasses are in flux. The usual recommendation to save the state dict then loses the convenience of having a single file and a single load statement that remembers the class. In working with simple (non-nested) `jax_dataclasses.pytree_dataclass` objects a lot lately, I have found the following to be convenient:
```python
import jax
import torch
import jax_dataclasses as jdc

def save(dc: jdc.pytree_dataclass, filename: str) -> None:
    torch.save((dc.__class__.__name__, jdc.asdict(dc)), filename)

def load(filename: str) -> jdc.pytree_dataclass:
    cls_name, data = torch.load(filename)
    # eval() looks the class up by name, so the current definition is used.
    return jax.device_put(eval(cls_name)(**data))
```
This way, if the given dataclass changes its definition but retains the same fields, loading still works. If the class name changes or the fields change, it's still trivial to get the state dict with a straight `torch.load`. And it's fast and avoids having a sidecar file.
From a user perspective, `torch.save`/`torch.load` are super convenient, so something implementing a similarly fast and simple interface would be great. (Maybe a solution utilizing `__getstate__` and `__setstate__` somehow?)
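For what it's worth, here's one purely speculative shape that idea could take; `Point` and everything in it is made up for illustration, and this is not an existing `jax_dataclasses` feature:

```python
import pickle

import jax
import jax_dataclasses as jdc
import numpy as np
from jax import numpy as jnp

@jdc.pytree_dataclass
class Point:
    xy: jnp.ndarray
    frozen: bool = jdc.static_field(default=False)

    def __getstate__(self):
        # jdc.asdict includes static fields, unlike flax's state dicts.
        return jdc.asdict(self)

    def __setstate__(self, state):
        for name, value in state.items():
            # Move arrays back onto a device; set static fields as-is.
            if isinstance(value, np.ndarray):
                value = jax.device_put(value)
            # object.__setattr__ bypasses the frozen-dataclass guard.
            object.__setattr__(self, name, value)

point = pickle.loads(pickle.dumps(Point(jnp.ones(2), frozen=True)))
assert point.frozen
```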
Thanks for clarifying!
I've run into similar issues, and something like the snippet you suggested sounds really useful. Main desired features on my end would be (1) better support for nested structures and (2) possibly avoiding the `eval()` call.
If your pytrees contain only objects that are serializable via PyYAML (this includes most Python objects, JAX/numpy arrays, flax `FrozenDict`s, etc., but not things like lambda functions), I have some dataclass serialization utilities that get us partially there. The basic idea is to take a class reference when deserializing, and recursively traverse the type annotations to understand how to reconstruct dataclasses. It's also pretty easy to update the YAML via a text editor if anything gets renamed.
Example of serialization:

```python
import jax_dataclasses as jdc
from jax import numpy as jnp
import dcargs

@jdc.pytree_dataclass
class Tree:
    number: float
    flag: bool = jdc.static_field()

yaml = dcargs.to_yaml(Tree(3.0, False))
print(yaml)

# Recovery
dcargs.from_yaml(Tree, yaml)
```
Output:

```yaml
!dataclass:Tree
flag: false
number: 3.0
```
Or, to mimic the `torch.save()` and `torch.load()` syntax:
```python
import dataclasses
import pathlib
from typing import Any, Type, TypeVar, Union

import dcargs

Dataclass = Any
DataclassT = TypeVar("DataclassT")

def save(instance: Dataclass, path: Union[str, pathlib.Path]) -> None:
    assert dataclasses.is_dataclass(instance)
    with open(path, "w") as file:
        file.write(dcargs.to_yaml(instance))

def load(cls: Type[DataclassT], path: Union[str, pathlib.Path]) -> DataclassT:
    assert dataclasses.is_dataclass(cls)
    with open(path, "r") as file:
        output = dcargs.from_yaml(cls, file.read())
    return output
```
Another possible source of inspiration is dacite, which should work out-of-the-box with our dataclass objects and might be used to achieve a similar goal, albeit with a separate set of constraints / possible failure cases. Will continue to think about this; it seems like there's room for a more robust solution.
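As a quick, unofficial sketch of what the dacite route could look like (reusing `Tree` from above):

```python
import dacite
import jax_dataclasses as jdc

# Serialize with jdc.asdict, reconstruct with dacite.from_dict. dacite
# type-checks values against the field annotations, which is where the
# extra constraints / failure cases (e.g. array-typed fields) come in.
tree = Tree(3.0, False)
data = jdc.asdict(tree)   # {'number': 3.0, 'flag': False}
restored = dacite.from_dict(data_class=Tree, data=data)
```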