Restoring flax TrainState from two different checkpoint structures in the new API
I'm trying to migrate from an older version of orbax (v0.6.4) to a newer one (v0.11.10), and I have a problem with restoring checkpoints.
I have two similar checkpoints, and I need to be able to load both of them.
Newer checkpoint version
$ tree -a /tmp/physmodjax/checkpoints_11po8hwr:v0/
/tmp/physmodjax/checkpoints_11po8hwr:v0/
├── checkpoints
│   └── 2351
│       ├── _CHECKPOINT_METADATA
│       ├── metrics
│       │   └── metrics
│       └── state
│           ├── d
│           │   └── 28136664d2dc46fa105f6e6d9bb416fa
│           ├── manifest.ocdbt
│           ├── _METADATA
│           ├── ocdbt.process_0
│           │   ├── d
│           │   │   ├── 23a6136b0d9ee643613c3b4a996bd1f8
│           │   │   ├── 4cf19a245d689524d873048019107e95
│           │   │   ├── 800aa1a9a609496e5ebb84e79682f962
│           │   │   ├── 870d7e0a616a05e2f894a49b684cb51e
│           │   │   ├── a275135cc9f930113aae8d2b20689de1
│           │   │   ├── e7cf94bd5e9366bbdbb30556eb89136d
│           │   │   └── ee2661f94896574e87a14149c0980612
│           │   └── manifest.ocdbt
│           └── _sharding
└── .hydra
    └── config.yaml
Older checkpoint version
$ tree -a /tmp/physmodjax/checkpoints_4sa4dawx:v0/
/tmp/physmodjax/checkpoints_4sa4dawx:v0/
├── checkpoints
│   └── 851
│       ├── _CHECKPOINT_METADATA
│       ├── default
│       │   ├── d
│       │   │   └── b45c3d9920fabadd4c2813f6219571b0
│       │   ├── manifest.ocdbt
│       │   ├── _METADATA
│       │   ├── ocdbt.process_0
│       │   │   ├── d
│       │   │   │   ├── 6b4ac6e43f39060f80d607021694e22a
│       │   │   │   ├── c01a0dffd52b512886457d45d527c5ac
│       │   │   │   ├── cd854566ac3aeb55292d17b25f6b6ed6
│       │   │   │   ├── cfe86c22634cbd3abb8d270f9dfc9d8d
│       │   │   │   ├── d448eb8462c257f7522a708129dcd17b
│       │   │   │   └── d7d7efe7b1a749042abf00025dbd7233
│       │   │   └── manifest.ocdbt
│       │   └── _sharding
│       └── metrics
│           └── metrics
└── .hydra
    └── config.yaml
The current code that saves the checkpoint (and produces the newer checkpoint version) is:
options = hydra.utils.instantiate(cfg.checkpoint_manager_options)

with obc.CheckpointManager(
    directory=Path(output_dir) / "checkpoints",
    options=options,
    item_handlers={"state": obc.PyTreeCheckpointHandler()},
) as checkpoint_manager:
    _ = train(
        model_cls=model_cls,
        datamodule=datamodule,
        cfg=cfg,
        checkpoint_manager=checkpoint_manager,
    )
    checkpoint_manager.wait_until_finished()
Inside the train function, in the training loop, I have the following code to save the checkpoint:
checkpoint_manager.save(
    step=epoch,
    args=obc.args.Composite(
        state=obc.args.PyTreeSave(state),
    ),
    metrics=metrics_to_log,
)
With the following code I was able to restore both types of checkpoints with v0.6.4:
import jax
jax.config.update("jax_platforms", "cuda")
print(jax.config.jax_platforms)
import hydra
from omegaconf import OmegaConf
from pathlib import Path
from flax.training import train_state
import flax.linen as nn
import orbax.checkpoint as obc
from typing import Any
def restore_experiment_state(
    run_path: Path,  # Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
    best: bool = True,  # If True, restore the best checkpoint instead of the latest
    step_to_restore: int = None,  # If not None, restore the checkpoint at this step
    kwargs: dict = {},  # Additional arguments to pass to the model
) -> tuple[train_state.TrainState, nn.Module, obc.CheckpointManager]:
    """
    Restores the train state from a run.

    Args:
        run_path (Path): Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")

    Returns:
    -------
        train_state.TrainState: The train state of the experiment
        nn.Module: The model used in the experiment
        CheckpointManager: The checkpoint manager
    """
    # Make sure the path is a Path object
    run_path = Path(run_path)

    # These are hardcoded, do not change
    ckpt_path = run_path / "checkpoints"
    config_path = run_path / ".hydra" / "config.yaml"

    cfg = OmegaConf.load(config_path)

    options = obc.CheckpointManagerOptions(
        max_to_keep=1,
        create=True,
        best_fn=lambda x: float(
            x["val/mae_rel"]
        ),  # Shouldn't be hardcoded here, not a problem atm because we only save one step, best
        best_mode="min",
    )

    with obc.CheckpointManager(
        ckpt_path,
        options=options,
        item_handlers={
            "state": obc.PyTreeCheckpointHandler(),
            "default": obc.PyTreeCheckpointHandler(),
        },
    ) as checkpoint_manager:
        model_cls: nn.Module = hydra.utils.instantiate(cfg.model)
        model = model_cls(training=False, **kwargs)

        # Get checkpoint metadata
        step = (
            checkpoint_manager.latest_step()
            if not best
            else checkpoint_manager.best_step()
        )
        step = step_to_restore if step_to_restore is not None else step
        metadatas = checkpoint_manager.item_metadata(step)

        print(f"Restoring checkpoint from step {step}...")

        # Backwards compatibility for older checkpoints
        if "state" in metadatas and metadatas.state is not None:
            metadata_state = metadatas.state
            ckpt_type = "state"
            print("This is a checkpoint with new formatting")
        elif "default" in metadatas and metadatas.default is not None:
            print("This is a checkpoint with old formatting")
            assert "model" in metadatas.default, "No model found in the checkpoint"
            metadata_state = metadatas.default["model"]
            ckpt_type = "default"
        else:
            raise ValueError("No state found in the checkpoint")

        # Check if the checkpoint has batch_stats
        if "batch_stats" in metadata_state:
            # Define TrainState with optional batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array
                batch_stats: Any = None  # Optional field

            # Initialize the empty state
            empty_state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
                batch_stats=metadata_state["batch_stats"],
            )
        else:
            # Define TrainState without batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array

            empty_state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
            )

        old_ckpt = {"model": empty_state}

        restored_checkpoint = checkpoint_manager.restore(
            step=step,
            args=obc.args.Composite(
                default=obc.args.PyTreeRestore(
                    item=old_ckpt,
                ),
                state=obc.args.PyTreeRestore(
                    item=empty_state,
                ),
            ),
        )

        if ckpt_type == "state":
            state = restored_checkpoint.state
        elif ckpt_type == "default":
            state = restored_checkpoint.default["model"]

        return state, model, checkpoint_manager
checkpoint_path = Path("/tmp/physmodjax/checkpoints_11po8hwr:v0/") # Newer checkpoint version
checkpoint_path = Path("/tmp/physmodjax/checkpoints_4sa4dawx:v0/") # Older checkpoint version
conf = OmegaConf.load(checkpoint_path / ".hydra" / "config.yaml")
kwargs = {"n_steps": conf.datamodule.num_steps_train[1]}
state, model, ckpt_manager = restore_experiment_state(
    checkpoint_path,
    kwargs=kwargs,
)
print("Restored model!!")
However, this only works when restoring to the same backend (cuda), and it gives me a warning about it:
$ python simple_restore_issue.py
cuda
Restoring checkpoint from step 851...
This is a checkpoint with old formatting
/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1330: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
Restored model!!
I was trying to get it to load correctly on both the cpu and cuda backends with shenanigans like the following
def apply_sharding(array, sharding):
    if isinstance(array, obc.metadata.value.ArrayMetadata):
        array.sharding = sharding

default_sharding = obc.metadata.sharding.SingleDeviceShardingMetadata(
    device_str=str(jax.local_devices()[0])
)
jax.tree_util.tree_map(
    lambda x: apply_sharding(x, default_sharding),
    metadata_state,
)
metadata_state = apply_default_sharding(metadata_state)
Inside the restore_experiment_state function, but after bashing my head into a wall for hours, I decided to update orbax to a newer version (v0.11.10) to use the new API and not code against a deprecated version.
But now I'm back at square one, because with v0.11.10 I can't restore the existing checkpoints, not even on the same backend.
Newer checkpoint version:
$ python simple_restore_issue.py
jax.config.jax_platforms: cuda
Restoring checkpoint from step 2351...
This is a checkpoint with new formatting
Traceback (most recent call last):
  File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 144, in <module>
    state, model, ckpt_manager = restore_experiment_state(
  File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 116, in restore_experiment_state
    restored_checkpoint = checkpoint_manager.restore(
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1566, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 545, in restore
    return super().restore(directory, *args, **kwargs)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 289, in restore
    restored = self._restore(directory, args=ckpt_args)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 308, in _restore
    return self._handler.restore(directory, args=args)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 831, in restore
    raise KeyError(
KeyError: 'Item "default" was not found in the checkpoint. Available items: [\'metrics\', \'state\']'
Older checkpoint version:
$ python simple_restore_issue.py
jax.config.jax_platforms: cuda
Restoring checkpoint from step 851...
This is a checkpoint with old formatting
/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1250: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.
warnings.warn(
WARNING:absl:[process=0][thread=MainThread] No metadata found for any process_index, checkpoint_dir=/tmp/physmodjax/checkpoints_4sa4dawx:v0/checkpoints/851/default. time elapsed=0.00021982192993164062 seconds. If the checkpoint does not contain jax.Array then it is expected. If checkpoint contains jax.Array then it should lead to an error eventually; if no error is raised then it is a bug.
Traceback (most recent call last):
  File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 144, in <module>
    state, model, ckpt_manager = restore_experiment_state(
  File "/home/carlos/projects/physmodjax_private/examples/evaluation/mlsp25/simple_restore_issue.py", line 116, in restore_experiment_state
    restored_checkpoint = checkpoint_manager.restore(
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1566, in restore
    restored = self._checkpointer.restore(restore_directory, args=args)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/async_checkpointer.py", line 545, in restore
    return super().restore(directory, *args, **kwargs)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 289, in restore
    restored = self._restore(directory, args=ckpt_args)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/checkpointers/checkpointer.py", line 308, in _restore
    return self._handler.restore(directory, args=args)
  File "/home/carlos/projects/physmodjax_private/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py", line 831, in restore
    raise KeyError(
KeyError: 'Item "state" was not found in the checkpoint. Available items: [\'default\', \'metrics\']'
What is the best way to do this? I feel like I'm going crazy; thanks a lot for any help you can provide.
Bonus points if you can also help me restore the checkpoints on both cpu and cuda backends. If not, I may open a separate issue for that later, once I get this part working.
Part of the problem in restoring on different backends is that you're not specifying the shardings for PyTreeRestore.
restored_checkpoint = checkpoint_manager.restore(
    step=step,
    args=obc.args.Composite(
        default=obc.args.PyTreeRestore(
            item=old_ckpt,
        ),
        state=obc.args.PyTreeRestore(
            item=empty_state,
        ),
    ),
)
PyTreeRestore needs item and also restore_args; check out checkpoint_utils.construct_restore_args. We generally recommend that folks use StandardSave and StandardRestore, though, as they are a bit easier to work with. StandardRestore just requires that you pass a tree containing abstract leaves (jax.ShapeDtypeStruct). These leaves should set a sharding property, which lets you restore on a different backend without needing the sharding metadata file.
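For illustration, here is a minimal sketch of that idea (an editor's example, not library documentation). It assumes a new-format checkpoint, a CheckpointManager whose "state" item accepts StandardRestore (e.g. one created without pinned item_handlers), and the empty_state tree from the script above, whose array-like leaves expose shape and dtype:

# Hypothetical target placement: restore every array onto one chosen device.
target_sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0])

def as_abstract(leaf):
    # Array-like leaves (they expose shape/dtype) become abstract arrays with an
    # explicit sharding; anything else (e.g. the integer step) passes through.
    if hasattr(leaf, "shape") and hasattr(leaf, "dtype"):
        return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=target_sharding)
    return leaf

abstract_state = jax.tree_util.tree_map(as_abstract, empty_state)

restored = checkpoint_manager.restore(
    step,
    args=obc.args.Composite(state=obc.args.StandardRestore(abstract_state)),
)
state = restored.state  # arrays come back already placed on target_sharding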
For the second part of your question, are you still calling restore the way shown above, where both default and state are passed? I think older versions of the API may have allowed this: since we were more reliant on up-front configuration of the handlers, we would ignore item names that are not present. Now, though, if you restore with both state and default, it will go looking for both and error out if one is not found. You should simply check the checkpoint before calling restore to see which item name is actually available, and restore with the arguments corresponding to that checkpoint.
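Roughly like this (a sketch reusing the empty_state and old_ckpt trees from your function; restore_args/shardings are omitted here since the first part of the answer covers them):

available = checkpoint_manager.item_metadata(step)

if "state" in available and available.state is not None:
    # New-format checkpoint: only ask for the "state" item.
    restored = checkpoint_manager.restore(
        step,
        args=obc.args.Composite(state=obc.args.PyTreeRestore(item=empty_state)),
    )
    state = restored.state
else:
    # Old-format checkpoint: only ask for the "default" item.
    restored = checkpoint_manager.restore(
        step,
        args=obc.args.Composite(default=obc.args.PyTreeRestore(item=old_ckpt)),
    )
    state = restored.default["model"]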
Hi, sorry for the late response; other things got in the way. Thanks a lot for the help. The second part was indeed me being a bit thick. I have now managed to make things work, but by creating a pytree for restore_args that sets restore_type to np.ndarray, restoring the metadata with that, creating the flax TrainState from the restored arrays, and then moving the state to the right device. I couldn't figure out whether there was a better way to do it using construct_restore_args. Is there a better way? How would I create a tree with abstract leaves from the metadata?
import jax

jax.config.update("jax_platforms", "cpu")
print(jax.config.jax_platforms)

import hydra
import numpy as np
from omegaconf import OmegaConf
from pathlib import Path
from typing import Any
from flax.training import train_state
import flax.linen as nn
import orbax.checkpoint as obc
def restore_experiment_state(
    run_path: Path,  # Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
    best: bool = True,  # If True, restore the best checkpoint instead of the latest
    best_metric: str = "val/mae_rel",  # Metric to use for best checkpoint selection
    step_to_restore: int = None,  # If not None, restore the checkpoint at this step, incompatible with `best`
    kwargs: dict = {},  # Additional arguments to pass to the model
    device: jax.Device = None,  # Device to restore the model on
) -> tuple[train_state.TrainState, nn.Module, obc.CheckpointManager]:
    """
    Restores the train state from a run.

    Args:
        run_path (Path): Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
        best (bool): If True, restore the best checkpoint instead of the latest
        step_to_restore (int): If not None, restore the checkpoint at this step
        kwargs (dict): Additional arguments to pass to the model
        device (jax.Device): Device to restore the model on

    Returns:
    -------
        train_state.TrainState: The train state of the experiment
        nn.Module: The model used in the experiment
        CheckpointManager: The checkpoint manager
    """
    # Make sure the path is a Path object
    run_path = Path(run_path)

    # These are hardcoded, do not change
    ckpt_path = run_path / "checkpoints"
    config_path = run_path / ".hydra" / "config.yaml"

    cfg = OmegaConf.load(config_path)

    # Check if either `best` or `step_to_restore` is set, but not both
    if best:
        if step_to_restore is not None:
            raise ValueError(
                "You cannot set both `best` and `step_to_restore`. Please choose one."
            )
        if best_metric is None:
            raise ValueError(
                "If `best=True`, you must provide a `best_metric` to determine the best checkpoint."
            )

    options = obc.CheckpointManagerOptions(
        max_to_keep=1,
        create=True,
        best_fn=lambda x: float(
            x[best_metric]
        ),  # Shouldn't be hardcoded here, not a problem atm because we only save one step, best
        best_mode="min",
    )

    def set_restore_type(x: Any) -> obc.RestoreArgs:
        return obc.RestoreArgs(restore_type=np.ndarray)

    with obc.CheckpointManager(
        ckpt_path,
        options=options,
        item_handlers={
            "state": obc.PyTreeCheckpointHandler(),
            "default": obc.PyTreeCheckpointHandler(),
        },
    ) as checkpoint_manager:
        model_cls: nn.Module = hydra.utils.instantiate(cfg.model)
        model = model_cls(training=False, **kwargs)

        # Get checkpoint metadata
        step = (
            checkpoint_manager.latest_step()
            if not best
            else checkpoint_manager.best_step()
        )
        step = step_to_restore if step_to_restore is not None else step
        metadatas = checkpoint_manager.item_metadata(step)

        print(f"Restoring checkpoint from step {step}...")

        # Restore the metadata
        # Backwards compatibility for older checkpoints
        if "state" in metadatas and metadatas.state is not None:
            restore_args_state = jax.tree_util.tree_map(
                set_restore_type, metadatas["state"]
            )
            metadatas = checkpoint_manager.restore(
                step=step,
                args=obc.args.Composite(
                    state=obc.args.PyTreeRestore(
                        item=metadatas["state"],
                        restore_args=restore_args_state,
                    ),
                ),
            )
            metadata_state = metadatas.state
        elif "default" in metadatas and metadatas.default is not None:
            print("This is a checkpoint with old formatting")
            if "model" not in metadatas.default:
                raise ValueError("No model found in the checkpoint")
            restore_args_default = jax.tree_util.tree_map(
                set_restore_type, metadatas["default"]
            )
            metadatas = checkpoint_manager.restore(
                step=step,
                args=obc.args.Composite(
                    default=obc.args.PyTreeRestore(
                        item=metadatas["default"],
                        restore_args=restore_args_default,
                    ),
                ),
            )
            metadata_state = metadatas.default["model"]
        else:
            raise ValueError("No state found in the checkpoint")

        if "batch_stats" in metadata_state:
            # Define TrainState with optional batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array
                batch_stats: Any = None  # Optional field

            # Initialize the state from the restored arrays
            state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
                batch_stats=metadata_state["batch_stats"],
            )
        else:
            # Define TrainState without batch_stats
            class TrainState(train_state.TrainState):
                key: jax.Array

            state = TrainState(
                key={},
                step=0,
                apply_fn=model.apply,
                params=metadata_state["params"],
                tx={},
                opt_state=metadata_state["opt_state"],
            )

        # If device is not specified, use the default device
        if device is None:
            device = jax.devices()[0]

        # Move the state to the specified device
        state = jax.device_put(state, device)

        return state, model, checkpoint_manager
How would I create a tree with abstract leaves from the metadata?
Metadata will return you a tree with leaves that are array-metadata-like objects, containing typical properties like shape, dtype, and sharding. https://github.com/google/orbax/blob/52251712d171ceb7417eb6bdb574059ced0cc3e9/checkpoint/orbax/checkpoint/_src/metadata/value.py#L63. sharding is a ShardingMetadata (which is abstract) but depending on the actual type (e.g. NamedShardingMetadata), it has properties that you can use to figure out what the original sharding was.
But I'm assuming you don't actually need that sharding info because you will specify your own. You can just tree_map given the metadata to create a tree of jax.ShapeDtypeStruct, specifying a new sharding for each array as desired.
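For example, a rough sketch of that tree_map (an editor's illustration, not library documentation): it assumes a new-format checkpoint so the metadata lives under the "state" item, a manager whose "state" item accepts StandardRestore, and that everything should land on a single CPU device.

step = checkpoint_manager.best_step()
metadatas = checkpoint_manager.item_metadata(step)

# Hypothetical target placement: every array on one CPU device.
cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0])

# Each metadata leaf carries shape and dtype, so it maps directly onto an
# abstract array that also specifies the sharding we want after restore.
abstract_tree = jax.tree_util.tree_map(
    lambda m: jax.ShapeDtypeStruct(m.shape, m.dtype, sharding=cpu_sharding),
    metadatas["state"],
)

restored = checkpoint_manager.restore(
    step,
    args=obc.args.Composite(state=obc.args.StandardRestore(abstract_tree)),
)
metadata_state = restored.state  # no jax.device_put needed afterwards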