Feature/Docs request: Saving and Loading checkpoints with optimizers and metrics
Dear Flax team,
I've been setting up the infrastructure to save and load my Flax models with Orbax, following the checkpointing guide in the documentation.
However, I noticed that a typical pattern when working with an NNX model is to bundle it together with the optimizer and the metrics:
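import optax
from flax import nnx

tx = optax.chain(
    ...
)
# Split the model, optimizer, metrics (and anything else) as one bundle
graphdef, graphstate = nnx.split(
    (model, nnx.Optimizer(model, tx, wrt=nnx.Param), metrics, ...)
)
model, optimizer, *rest = nnx.merge(graphdef, graphstate)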
With Orbax, saving and restoring the graphstate is trivial. The graphdef, however, is a whole other story: to reconstruct it, one needs to know exactly how it was created. What I've been doing so far is saving the import path of the model class and its constructor parameters as metadata; to restore, I reconstruct the model, split it, load the graphstate, and replace it. But when I went to use this for a real task, I ran into the problem of reconstructing the graphdef when it contains more than just the model. The docs don't show how to do this.
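The metadata workaround described above can be sketched with only the standard library. `save_spec` and `load_spec` are hypothetical helper names, not part of Flax or Orbax: the first records a class's module, qualified name, and constructor keyword arguments as JSON, and the second re-imports the class and calls it. This assumes the arguments are JSON-serializable and that the class definition has not changed since saving; `datetime.timedelta` stands in for a model class.

```python
import importlib
import json
from datetime import timedelta
from typing import Any


def save_spec(cls: type, **kwargs: Any) -> str:
    """Hypothetical helper: record where a class lives and how it was constructed."""
    return json.dumps({
        "module": cls.__module__,
        "qualname": cls.__qualname__,
        "kwargs": kwargs,
    })


def load_spec(spec: str) -> Any:
    """Hypothetical helper: re-import the class recorded by save_spec and instantiate it."""
    data = json.loads(spec)
    obj = importlib.import_module(data["module"])
    for name in data["qualname"].split("."):  # also walks nested classes like Outer.Inner
        obj = getattr(obj, name)
    return obj(**data["kwargs"])


# Usage: datetime.timedelta stands in for a model class.
spec = save_spec(timedelta, days=2, hours=3)
restored = load_spec(spec)
```

For an NNX model, one would presumably call `load_spec` inside `nnx.eval_shape` (e.g. `nnx.eval_shape(lambda: load_spec(spec))`) to get the graphdef and an abstract state without allocating parameters.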
The minor issue is separating the model's graphstate from everything else, so that at least the model can be saved.
The larger issue is how to manage the whole bundle of information held in graphstate and graphdef.
A nice solution could be a to_json method and a from_json constructor for graphdef, making it easy to serialize with Orbax. One would then save the graphdef as metadata, and reconstruction would be as simple as loading both elements, constructing a new graphdef from the JSON, and merging.
Best regards.
@cdelv thanks for the question! I'd like to understand it better. I agree that Orbax is mostly used to store and restore the state:
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()
# ckpt_dir: an absolute pathlib.Path to the checkpoint directory
# Store
_, state = nnx.split(model)  # where model is an instance of Model
checkpointer.save(ckpt_dir / 'state', state)
# Restore
abstract_model = nnx.eval_shape(lambda: Model(...))
graphdef, abstract_state = nnx.split(abstract_model)
state_restored = checkpointer.restore(ckpt_dir / 'state', abstract_state)
model = nnx.merge(graphdef, state_restored)
This code assumes that we can always create an instance of the model class in code, so we only need to save and load its parameters (the state).
It seems your question is about how to serialize and deserialize the nnx.Module itself. Right now I'm not sure how this can be done with pure Flax, but I'll ask around (including whether to_json and from_json features would make sense).
I can think of third-party configuration libraries like Hydra, where the model configuration is defined in a YAML file and instantiated in code, assuming the class definition exists (https://hydra.cc/docs/advanced/instantiate_objects/overview/#simple-usage):
(I haven't tested the example below.)
# Config yaml file
model:
  _target_: my_app.Model
  in_dim: 32
  hidden_dim: 128
  out_dim: 10
  rng_key: 17
from flax import nnx
from hydra.utils import instantiate

class Model(nnx.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, rng_key: int):
        rngs = nnx.Rngs(rng_key)
        self.fc1 = nnx.Linear(in_dim, hidden_dim, rngs=rngs)
        self.fc2 = nnx.Linear(hidden_dim, out_dim, rngs=rngs)
    ...

# cfg: the Hydra config loaded from the YAML file above
model = instantiate(cfg.model)
graphdef, _ = nnx.split(model)
A while ago I experimented with writing something like this for Flax; you can find it here: https://github.com/NeuralQXLab/nqxpack
Essentially I wanted to be able to do nqxpack.save(nnx_network, filename) and nqxpack.load(filename).
Getting it to work for flax.linen was remarkably easy, as a linen.Module is really just a dataclass, so you can recursively serialise it and its attributes by calling dataclasses.fields() on it.
As long as YOU do not change the class definition, you will be able to reload it without problems.
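Since a dataclass exposes its constructor arguments through `dataclasses.fields()`, the recursive approach can be sketched in a few lines. `to_spec` is a hypothetical helper, not nqxpack's actual code, and plain dataclasses stand in for linen modules; a real linen serialiser would additionally have to deal with functions such as `nn.gelu` and with linen's bookkeeping fields such as `parent` and `name`. Tuples are recorded as calls to `builtins.tuple`, following the `_target_`/`_args_` convention of the JSON in this post.

```python
import dataclasses
import json
from typing import Any


def to_spec(obj: Any) -> Any:
    """Hypothetical sketch: recursively turn dataclass instances into `_target_` dicts."""
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        cls = type(obj)
        spec = {"_target_": f"{cls.__module__}.{cls.__qualname__}"}
        for field in dataclasses.fields(obj):
            if field.init:  # only constructor arguments are needed to rebuild the object
                spec[field.name] = to_spec(getattr(obj, field.name))
        return spec
    if isinstance(obj, tuple):
        # Tuples are not native JSON, so record them as a call to builtins.tuple.
        return {"_target_": "builtins.tuple", "_args_": [[to_spec(x) for x in obj]]}
    if isinstance(obj, list):
        return [to_spec(x) for x in obj]
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    raise TypeError(f"cannot serialise objects of type {type(obj).__name__}")


# Plain dataclasses standing in for dataclass-based layers such as linen.Dense.
@dataclasses.dataclass
class Dense:
    features: int
    use_bias: bool = True


@dataclasses.dataclass
class Sequential:
    layers: tuple


model = Sequential(layers=(Dense(features=2), Dense(features=1, use_bias=False)))
spec = to_spec(model)
print(json.dumps(spec, indent=2))
```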
For example, the following
import jax
import numpy as np
import nqxpack
from flax import linen as nn

model = nn.Sequential((
    nn.Dense(features=2),
    nn.gelu,
    nn.Dense(features=1),
    jax.numpy.squeeze,
))
variables = model.init(jax.random.key(1), jax.numpy.ones((2, 4)))
# For the moment nqxpack cannot serialise jax arrays (this could easily be
# implemented), so convert the variables to numpy arrays first.
variables_np = jax.tree.map(np.asarray, variables)
nqxpack.save({'model': model, 'variables': variables_np}, "mymodel.nk")
loaded_dict = nqxpack.load("mymodel.nk")
loaded_model, loaded_variables = loaded_dict['model'], loaded_dict['variables']
gets serialised to a Hydra-inspired JSON:
"model": {
    "_target_": "flax.linen.combinators.Sequential",
    "layers": {
        "_target_": "builtins.tuple",
        "_args_": [
            [
                {
                    "_target_": "flax.linen.linear.Dense",
                    "features": 2,
                    "use_bias": true,
                    "dtype": null,
                    "param_dtype": "< jax.numpy.float32 >",
                    "precision": null,
                    "kernel_init": {
                        "_target_": "jax.nn.initializers.variance_scaling",
                        "batch_axis": {
                            "_target_": "builtins.tuple"
                        },
                        "distribution": "truncated_normal",
                        "in_axis": -2,
                        "mode": "fan_in",
                        "out_axis": -1,
                        "scale": 1.0
                    },
                    "bias_init": "< jax.nn.initializers.zeros >",
                    ...
                },
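Loading reverses the process: each `_target_` is imported and called with its recursively instantiated `_args_` as positional arguments and the remaining keys as keyword arguments. Below is a simplified, hypothetical `instantiate_spec` in that spirit, covering only the `_target_`/`_args_` convention visible in the JSON above; it is not nqxpack's or Hydra's implementation and does not handle special strings such as `"< jax.numpy.float32 >"`. Standard-library classes stand in for Flax layers.

```python
import importlib
from datetime import timedelta
from fractions import Fraction
from typing import Any


def locate(path: str) -> Any:
    """Import an attribute from a dotted path such as 'datetime.timedelta'."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_spec(spec: Any) -> Any:
    """Hypothetical sketch: rebuild objects from `_target_`/`_args_` dicts, recursively."""
    if isinstance(spec, list):
        return [instantiate_spec(item) for item in spec]
    if not isinstance(spec, dict):
        return spec  # plain JSON scalar
    kwargs = {k: instantiate_spec(v) for k, v in spec.items() if k not in ("_target_", "_args_")}
    if "_target_" not in spec:
        return kwargs  # ordinary dict without a target
    args = instantiate_spec(spec.get("_args_", []))
    return locate(spec["_target_"])(*args, **kwargs)


spec = {
    "_target_": "builtins.tuple",
    "_args_": [[
        {"_target_": "datetime.timedelta", "days": 1},
        {"_target_": "fractions.Fraction", "_args_": [3, 4]},
        {"_target_": "builtins.tuple"},  # empty tuple, like "batch_axis" above
    ]],
}
obj = instantiate_spec(spec)
expected = (timedelta(days=1), Fraction(3, 4), ())
```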
For flax.nnx's graphdef, instead, the story is much more complicated.
The graphdef itself is a pytree, but a very complex one.
In the last ~6 months its internal structure has changed a few times, and it's unclear to me what the best way to handle it is.
A few months ago I could usually serialise nnx graphdefs with that package, but nowadays that no longer works.
I have long wanted to check with @cgarciae what the best way to handle this for nnx would be.