Issues checkpointing optimizer state using Optax, nnx.Optimizer, and Orbax
I am running into an error while trying to checkpoint an Optax optimizer state, wrapped as an nnx.Optimizer, using the Orbax checkpointing library.
```
ValueError: Unsupported type: <class 'flax.nnx.training.optimizer.OptArray'> for key: ('0', 'count'). Supported types are (<class 'int'>, <class 'float'>, <class 'numpy.ndarray'>, <class 'jax.Array'>).
```
I am using the following packages:
- flax 0.10.2
- jax 0.4.36
- jax-cuda12-pjrt 0.4.36
- jax-cuda12-plugin 0.4.36
- jaxlib 0.4.36
- optax 0.2.4
- orbax-checkpoint 0.10.2
Minimal repro:
```python
from flax import nnx
import numpy as np
import orbax.checkpoint as ocp
import optax
import os
import pathlib


class MyModel(nnx.Module):
    def __init__(self, rngs):
        self.linear = nnx.Linear(in_features=4, out_features=1, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)


rngs = nnx.Rngs(0)
model = MyModel(rngs)
tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx)

checkpointDir = pathlib.Path('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()
checkpointer.save(checkpointDir, optimizer.opt_state, force=True)
```
I see y'all have documentation about using Orbax for model checkpointing, but I don't see any official info about checkpointing optimizer state.
I see an older GitHub discussion where the Optax folks recommended cloudpickle: https://github.com/google-deepmind/optax/discussions/180
I tried adding a custom serialization/deserialization for nnx.training.optimizer.OptArray via https://orbax.readthedocs.io/en/latest/guides/checkpoint/custom_handlers.html#custom-serialization-deserialization, but then just ran into the next error:
```
ValueError: TypeHandler lookup failed for: type=<class 'flax.nnx.training.optimizer.OptVariable'>
```
Maybe I'd also need to add a type handler for OptVariable, and perhaps even for Variable? That seems quite annoying. Am I missing something?
Thanks for your time!
It seems that also adding a type handler for nnx.training.optimizer.OptVariable is sufficient. I'd rather not need to look into the details of these classes to write serializers/deserializers though.
@SandSnip3r I'm hoping you could share what sounds like the solution you found. I've also been having issues with this.
Hey @SandSnip3r, try serializing the State instead:
```python
checkpoint = nnx.state(optimizer)
# optional but works better
checkpoint = checkpoint.to_pure_dict()
checkpointer.save(checkpointDir, checkpoint, force=True)
```
Then to load it:
```python
checkpoint = checkpointer.restore(...)
nnx.update(optimizer, checkpoint)
```
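Putting the two pieces together with the repro above, the full round trip looks roughly like this. This is a sketch, not a tested recipe: it assumes the `MyModel` class from the repro, and the exact restore arguments can vary across Orbax versions (passing an abstract target, as shown further down, is the more explicit option).

```python
import pathlib

import optax
import orbax.checkpoint as ocp
from flax import nnx

# Build the model/optimizer exactly as in the repro.
model = MyModel(nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))

checkpoint_dir = pathlib.Path('/tmp/my-checkpoints/optimizer-state')
checkpointer = ocp.StandardCheckpointer()

# Save the nnx State (as a pure dict of arrays) instead of optimizer.opt_state.
checkpoint = nnx.state(optimizer).to_pure_dict()
checkpointer.save(checkpoint_dir, checkpoint, force=True)

# Later (e.g. in a fresh process), rebuild the optimizer, restore the dict,
# and copy the values back in.
restored = checkpointer.restore(checkpoint_dir)
nnx.update(optimizer, restored)
```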
Thanks @cgarciae! I'll give this a try when I get time.
Why does .to_pure_dict() work better?
What about the arguments to checkpointer.restore? The checkpointer requires a structure to load the data into, right? Like an abstract state tree? What would I use for this?
> Why does .to_pure_dict() work better?
It removes the VariableState objects that contain metadata and leaves only the Arrays. It also makes checkpoint a pure dict which is usually easier to serialize.
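For example, with the `MyModel` from the repro above, the difference looks roughly like this (illustrative; the exact reprs may differ):

```python
state = nnx.state(model)
# An nnx.State whose leaves are VariableState objects carrying metadata, roughly:
#   {'linear': {'kernel': VariableState(type=Param, value=Array(...)),
#               'bias': VariableState(type=Param, value=Array(...))}}

pure = state.to_pure_dict()
# A plain nested dict whose leaves are just the arrays:
#   {'linear': {'kernel': Array(...), 'bias': Array(...)}}
```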
> The checkpointer requires a structure to load the data into

I think if you have a pure dictionary you don't need the target structure. Otherwise, just recreate the checkpoint structure using nnx.eval_shape, e.g.
```python
abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()
...
```
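Filled in a bit more, here is a sketch of the restore path under these assumptions: `create_optimizer` is a placeholder that must rebuild the same model/optimizer structure that was saved, and `checkpoint_dir` is the path the checkpoint was saved to.

```python
import optax
import orbax.checkpoint as ocp
from flax import nnx


def create_optimizer():
    # Placeholder: must build the same structure that was saved.
    model = MyModel(nnx.Rngs(0))
    return nnx.Optimizer(model, optax.adam(1e-3))


# Build the target structure without allocating real parameters.
abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()

# Restore into that structure, then push the values into a real optimizer.
checkpointer = ocp.StandardCheckpointer()
restored = checkpointer.restore(checkpoint_dir, target)
optimizer = create_optimizer()
nnx.update(optimizer, restored)
```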
@cgarciae Honestly, great explanation. Based on this conversation I was able to put this together. It's not perfect, but it works pretty well.
```python
import shutil
from pathlib import Path
from typing import Any

import jax
import orbax.checkpoint
from flax import nnx
from flax.training import orbax_utils
from jax.sharding import Mesh

# TrainingConfig and create_tx are defined elsewhere in my training code.


class CheckpointManager:
    def __init__(self, checkpoint_dir: str, keep_n: int = 3):
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        self.keep_n = keep_n

    def _cleanup_old_checkpoints(self, step: int):
        # Get all checkpoint steps except 'best'
        checkpoints = []
        for path in self.checkpoint_dir.glob("model-*"):
            if "best" not in str(path):
                try:
                    ckpt_step = int(path.name.split("-")[-1])
                    checkpoints.append(ckpt_step)
                except ValueError:
                    continue
        # Sort and remove old checkpoints while keeping newest n
        checkpoints.sort(reverse=True)
        for old_step in checkpoints[self.keep_n:]:
            for prefix in ["model", "optimizer", "metrics"]:
                path = self.checkpoint_dir / f"{prefix}-{old_step}"
                if path.exists():
                    shutil.rmtree(path)

    def save_model(self, model: Any, step: int, is_best: bool = False):
        state = nnx.state(model).to_pure_dict()
        self._save(obj=state, filename=f"model-{step}")
        if is_best:
            self._save(obj=state, filename="model-best")
        self._cleanup_old_checkpoints(step)

    def save_optimizer(self, optimizer: nnx.Optimizer, step: int):
        state = nnx.state(optimizer).to_pure_dict()
        self._save(obj=state, filename=f"optimizer-{step}")

    def save_training_state(self, step: int, total_tokens: int, metrics: dict = None):
        metrics = {"step": step, "total_tokens": total_tokens, "metrics": metrics}
        self._save(obj=metrics, filename=f"metrics-{step}")

    def restore_model(self, model_cls: Any, model_config: dict, mesh: Mesh, step: int) -> Any:
        abs_model = nnx.eval_shape(lambda: model_cls(**model_config, rngs=nnx.Rngs(0)))
        abs_state = nnx.state(abs_model).to_pure_dict()
        target = jax.tree.map(
            lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
            abs_state,
            nnx.get_named_sharding(nnx.state(abs_model), mesh),
        )
        state = self._restore(filename=f"model-{step}", target=target)
        model = nnx.merge(abs_model, state)
        return model

    def restore_optimizer(
        self, model: Any, train_config: TrainingConfig, step: int
    ) -> nnx.Optimizer:
        tx = create_tx(train_config)  # function where I construct my optax optimizer
        abs_optimizer = nnx.eval_shape(lambda: nnx.Optimizer(model, tx))
        abs_state = nnx.state(abs_optimizer).to_pure_dict()
        state = self._restore(filename=f"optimizer-{step}", target=abs_state)
        optimizer = nnx.Optimizer(model, tx)
        nnx.update(optimizer, state)
        return optimizer

    def restore_training_state(self, step: int) -> dict:
        """Restore training metrics and token count."""
        return self._restore(filename=f"metrics-{step}")

    def _save(self, obj: Any, filename: str):
        path = self.checkpoint_dir.absolute() / filename
        save_args = orbax_utils.save_args_from_target(obj)
        self.checkpointer.save(str(path), obj, save_args=save_args)

    def _restore(self, filename: str, target: Any = None) -> Any:
        path = self.checkpoint_dir.absolute() / filename
        return self.checkpointer.restore(str(path), target=target)
```
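Usage ends up looking roughly like this (a sketch; `MyModel`, `my_mesh`, and `my_train_config` are placeholders for your own definitions):

```python
manager = CheckpointManager(checkpoint_dir="/tmp/my-checkpoints", keep_n=3)

# During training:
manager.save_model(model, step=1000, is_best=True)
manager.save_optimizer(optimizer, step=1000)
manager.save_training_state(step=1000, total_tokens=123_456, metrics={"loss": 2.3})

# When resuming:
model = manager.restore_model(MyModel, model_config={}, mesh=my_mesh, step=1000)
optimizer = manager.restore_optimizer(model, my_train_config, step=1000)
training_state = manager.restore_training_state(step=1000)
```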
As described, this also works without using to_pure_dict() by simply changing the last line of my repro:
```python
checkpointer.save(checkpointDir, nnx.state(optimizer), force=True)
```
Then restoring can be done with:

```python
import jax

tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx)
abstractOptStateTree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, nnx.state(optimizer))
optimizerState = checkpointer.restore(checkpointDir, abstractOptStateTree)
nnx.update(optimizer, optimizerState)
```
Thanks a lot, @cgarciae.
What do you think about adding a small blurb about saving & restoring optimizer states in the flax documentation section about checkpointing? https://flax.readthedocs.io/en/latest/guides/checkpointing.html#save-checkpoints I think this would be nice, especially since flax.nnx is offering an API for optimizers.
I'm having issues following this thread. It appears that the APIs are evolving, and I'm having a hard time getting Orbax to work with NNX. Any updated documentation somewhere with simple manager examples?
@BeeGass could you share all the imports for your manager?