
Issues checkpointing optimizer state using Optax, nnx.Optimizer, and Orbax

SandSnip3r opened this issue Dec 09 '24 · 8 comments

I am running into an error while trying to checkpoint an Optax optimizer state, wrapped as an nnx.Optimizer, using the Orbax checkpointing library.

ValueError: Unsupported type: <class 'flax.nnx.training.optimizer.OptArray'> for key: ('0', 'count'). Supported types are (<class 'int'>, <class 'float'>, <class 'numpy.ndarray'>, <class 'jax.Array'>).

I am using packages:

  • flax 0.10.2
  • jax 0.4.36
  • jax-cuda12-pjrt 0.4.36
  • jax-cuda12-plugin 0.4.36
  • jaxlib 0.4.36
  • optax 0.2.4
  • orbax-checkpoint 0.10.2

Minimal repro:

from flax import nnx
import orbax.checkpoint as ocp
import optax
import pathlib

class MyModel(nnx.Module):
  def __init__(self, rngs):
    self.linear = nnx.Linear(in_features=4, out_features=1, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)

rngs = nnx.Rngs(0)
model = MyModel(rngs)

tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx)

checkpointDir = pathlib.Path('/tmp/my-checkpoints/')
checkpointer = ocp.StandardCheckpointer()

checkpointer.save(checkpointDir, optimizer.opt_state, force=True)

I see y'all have documentation about using Orbax for model checkpointing, but I don't see any official info about checkpointing optimizer state.

I see an older github issue where the Optax folks recommended cloudpickle: https://github.com/google-deepmind/optax/discussions/180

I tried adding a custom serialization/deserialization for nnx.training.optimizer.OptArray via https://orbax.readthedocs.io/en/latest/guides/checkpoint/custom_handlers.html#custom-serialization-deserialization, but then just ran into the next error:

ValueError: TypeHandler lookup failed for: type=<class 'flax.nnx.training.optimizer.OptVariable'>

Maybe I'd also need to add a type handler for OptVariable, and possibly even Variable? That seems quite tedious. Am I missing something?

Thanks for your time!

SandSnip3r avatar Dec 09 '24 17:12 SandSnip3r

It seems that also adding a type handler for nnx.training.optimizer.OptVariable is sufficient. I'd rather not need to look into the details of these classes to write serializers/deserializers though.

SandSnip3r avatar Dec 09 '24 18:12 SandSnip3r

It seems that also adding a type handler for nnx.training.optimizer.OptVariable is sufficient. I'd rather not need to look into the details of these classes to write serializers/deserializers though.

@SandSnip3r I'm hoping you could share what the solution turned out to be. I've also been having issues with this.

BeeGass avatar Dec 10 '24 21:12 BeeGass

Hey @SandSnip3r, try serializing the State instead:

checkpoint = nnx.state(optimizer)
# optional but works better
checkpoint = checkpoint.to_pure_dict()

checkpointer.save(checkpointDir, checkpoint, force=True)

Then to load it:

checkpoint = checkpointer.restore(...)
nnx.update(optimizer, checkpoint)

cgarciae avatar Dec 10 '24 22:12 cgarciae

Thanks @cgarciae! I'll give this a try when I get time.

Why does .to_pure_dict() work better?

What about the arguments to checkpointer.restore? The checkpointer requires a structure to load the data into, right? Like an abstract state tree? What would I use for this?

SandSnip3r avatar Dec 10 '24 22:12 SandSnip3r

Why does .to_pure_dict() work better?

It removes the VariableState objects that contain metadata and leaves only the arrays. It also makes the checkpoint a pure dict, which is usually easier to serialize.
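
To make that concrete with the optimizer from the repro above (just a sketch of the same idea):

state = nnx.state(optimizer)   # nnx.State; leaves are VariableState objects carrying metadata
pure = state.to_pure_dict()    # plain nested dict; leaves are bare arrays
checkpointer.save(checkpointDir, pure, force=True)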

The checkpointer requires a structure to load the data into

I think if you have a pure dictionary you don't need the target structure. Otherwise, just recreate the checkpoint structure using nnx.eval_shape, e.g.:

abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()
...
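
Putting it together, a restore might look roughly like this (a sketch; create_optimizer() stands in for however you originally build the nnx.Optimizer, and checkpointer/checkpointDir are the ones from the repro):

abstract_optimizer = nnx.eval_shape(lambda: create_optimizer())
target = nnx.state(abstract_optimizer).to_pure_dict()

# Restore into the abstract structure, then update a freshly created optimizer.
checkpoint = checkpointer.restore(checkpointDir, target)

optimizer = create_optimizer()
nnx.update(optimizer, checkpoint)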

cgarciae avatar Dec 10 '24 22:12 cgarciae

@cgarciae Honestly, great explanation. Based on this conversation I was able to put this together. It's not perfect, but it works pretty well.

# Imports assumed by this class; TrainingConfig and create_tx are defined elsewhere in my project.
import shutil
from pathlib import Path
from typing import Any

import jax
import orbax.checkpoint
from flax import nnx
from flax.training import orbax_utils
from jax.sharding import Mesh


class CheckpointManager:
    def __init__(self, checkpoint_dir: str | Path, keep_n: int = 3):
        # Keep the directory as a Path so glob() and the / operator below work.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        self.keep_n = keep_n

    def _cleanup_old_checkpoints(self, step: int):
        # Get all checkpoint steps except 'best'
        checkpoints = []
        for path in self.checkpoint_dir.glob("model-*"):
            if "best" not in str(path):
                try:
                    ckpt_step = int(path.name.split("-")[-1])
                    checkpoints.append(ckpt_step)
                except ValueError:
                    continue

        # Sort and remove old checkpoints while keeping newest n
        checkpoints.sort(reverse=True)
        for old_step in checkpoints[self.keep_n :]:
            for prefix in ["model", "optimizer", "metrics"]:
                path = self.checkpoint_dir / f"{prefix}-{old_step}"
                if path.exists():
                    shutil.rmtree(path)

    def save_model(self, model: Any, step: int, is_best: bool = False):
        state = nnx.state(model).to_pure_dict()
        self._save(obj=state, filename=f"model-{step}")
        if is_best:
            self._save(obj=state, filename="model-best")
        self._cleanup_old_checkpoints(step)

    def save_optimizer(self, optimizer: nnx.Optimizer, step: int):
        state = nnx.state(optimizer).to_pure_dict()
        self._save(obj=state, filename=f"optimizer-{step}")

    def save_training_state(self, step: int, total_tokens: int, metrics: dict = None):
        metrics = {"step": step, "total_tokens": total_tokens, "metrics": metrics}
        self._save(obj=metrics, filename=f"metrics-{step}")

    def restore_model(self, model_cls: Any, model_config: dict, mesh: Mesh, step: int) -> Any:
        abs_model = nnx.eval_shape(lambda: model_cls(**model_config, rngs=nnx.Rngs(0)))
        abs_state = nnx.state(abs_model).to_pure_dict()

        target = jax.tree.map(
            lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
            abs_state,
            nnx.get_named_sharding(nnx.state(abs_model), mesh),
        )

        state = self._restore(filename=f"model-{step}", target=target)
        model = nnx.merge(abs_model, state)
        return model

    def restore_optimizer(
        self, model: Any, train_config: TrainingConfig, step: int
    ) -> nnx.Optimizer:
        tx = create_tx(train_config) # function where I construct my optax optimizer
        abs_optimizer = nnx.eval_shape(lambda: nnx.Optimizer(model, tx))
        abs_state = nnx.state(abs_optimizer).to_pure_dict()

        state = self._restore(filename=f"optimizer-{step}", target=abs_state)
        optimizer = nnx.Optimizer(model, tx)
        nnx.update(optimizer, state)
        return optimizer

    def restore_training_state(self, step: int) -> dict:
        """Restore training metrics and token count."""
        return self._restore(filename=f"metrics-{step}")

    def _save(self, obj: Any, filename: str):
        path = self.checkpoint_dir.absolute() / filename
        save_args = orbax_utils.save_args_from_target(obj)
        self.checkpointer.save(str(Path(path)), obj, save_args=save_args)

    def _restore(self, filename: str, target: Any = None) -> Any:
        path = self.checkpoint_dir.absolute() / filename
        return self.checkpointer.restore(str(Path(path)), target=target)
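
For reference, usage looks roughly like this (a sketch; ModelClass, model_config, mesh, train_config, and the step/token counters are placeholders from my training loop):

manager = CheckpointManager(checkpoint_dir="/tmp/run-checkpoints", keep_n=3)

# Save everything for the current step.
manager.save_model(model, step=step, is_best=False)
manager.save_optimizer(optimizer, step=step)
manager.save_training_state(step=step, total_tokens=total_tokens, metrics={"loss": float(loss)})

# Later, restore in the same order: the model first, then the optimizer built on top of it.
model = manager.restore_model(ModelClass, model_config, mesh, step=step)
optimizer = manager.restore_optimizer(model, train_config, step=step)
training_state = manager.restore_training_state(step=step)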

BeeGass avatar Dec 11 '24 07:12 BeeGass

As described, this also works without using to_pure_dict() by simply changing the last line of my repro:

checkpointer.save(checkpointDir, nnx.state(optimizer), force=True)

Then restoring can be done with

tx = optax.adam(1e-3)
optimizer = nnx.Optimizer(model, tx)
abstractOptStateTree = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, nnx.state(optimizer))
optimizerState = checkpointer.restore(checkpointDir, abstractOptStateTree)
nnx.update(optimizer, optimizerState)

Thanks a lot, @cgarciae.

What do you think about adding a small blurb about saving & restoring optimizer states in the flax documentation section about checkpointing? https://flax.readthedocs.io/en/latest/guides/checkpointing.html#save-checkpoints I think this would be nice, especially since flax.nnx is offering an API for optimizers.

SandSnip3r avatar Dec 11 '24 15:12 SandSnip3r

I'm having issues following this thread. It appears that the APIs are evolving, and I'm having a hard time getting Orbax to work with NNX. Any updated documentation somewhere with simple manager examples?

@BeeGass can you share all the imports for your manager?

mfouesneau avatar Jul 07 '25 08:07 mfouesneau