
`step_prefix` cannot contain `_` -- Checkpoint manager does not recognize multiple `_`.

Open scott-yj-yang opened this issue 10 months ago • 2 comments

Bug Description:

When I created a checkpoint manager option like the following,

import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
with ocp.CheckpointManager(
    ".../model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
    options=options,
) as mngr:
    mngr.restore(0)

and my directory looks like this:

[Image: screenshot of the checkpoint directory]

it raises the following ValueError when instantiating the manager object.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 5
      1 import orbax.checkpoint as ocp
      4 options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
----> 5 with ocp.CheckpointManager(
      6     "/root/vast/scott-yang/track-mjx/model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
      7     options=options,
      8 ) as mngr:
      9     mngr.restore(0)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:685, in CheckpointManager.__init__(self, directory, checkpointers, options, metadata, item_names, item_handlers, logger, handler_registry)
    675   self._cleanup_tmp_directories()
    677 self._step_name_format = (
    678     self._options.step_name_format
    679     or step_lib.standard_name_format(
   (...)
    682     )
    683 )
--> 685 self._checkpoints = self._load_checkpoint_infos()
    687 self._metadata_checkpointer = Checkpointer(
    688     JsonCheckpointHandler(
    689         multiprocessing_options=self._multiprocessing_options
   (...)
    694     temporary_path_class=self._options.temporary_path_class,
    695 )
    696 if self._options.read_only and not self._metadata_path().exists():

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:1431, in CheckpointManager._load_checkpoint_infos(self)
   1423 """Loads a list of CheckpointInfo for existing checkpoints.
   1424 
   1425 If none are present, returns empty list.
   (...)
   1428   a list of CheckpointInfo, sorted by increasing step.
   1429 """
   1430 start = time.time()
-> 1431 steps = utils.checkpoint_steps(
   1432     self.directory, self._options.single_host_load_and_broadcast
   1433 )
   1434 steps.sort()  # Prefer in-place sort.
   1436 if not steps:

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:698, in checkpoint_steps(checkpoint_dir, single_host_load_and_broadcast)
    696   padded_step_list = multihost.broadcast_one_to_all(padded_step_list)
    697   return [step for step in padded_step_list if step >= 0]
--> 698 return _checkpoint_steps(checkpoint_dir)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:682, in checkpoint_steps.<locals>._checkpoint_steps(path)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
--> 682   return [
    683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:683, in <listcomp>(.0)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
    682   return [
--> 683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:645, in step_from_checkpoint_name(name)
    643 elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    644   return int(tmp_match.group(1))
--> 645 raise ValueError(f'Unrecognized name format: {name}.')

ValueError: Unrecognized name format: ppo_networks_1024000.

Specifically, when I check step.py (https://github.com/google/orbax/blob/b32d1f9806b96e445d86561d079aa58b336b49be/checkpoint/orbax/checkpoint/_src/path/step.py#L657-L667), it assumes that splitting the name on `_` yields exactly two members, so a prefix that itself contains `_` is rejected. Input validation of the prefix is needed.
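To illustrate the failure mode described above, here is a minimal sketch in plain Python. It is not Orbax's actual code: `parse_step_strict` and `parse_step_rsplit` are hypothetical names, and the strict variant only imitates the "exactly two members after splitting on `_`" assumption. The second function shows one possible alternative, splitting once from the right so that the prefix may itself contain `_`.

```python
def parse_step_strict(name: str) -> int:
    """Hypothetical parser: accepts only '<digits>' or '<prefix>_<digits>'."""
    if name.isdigit():
        return int(name)
    parts = name.split("_")
    # Exactly two members are expected, so a prefix containing "_"
    # (e.g. "ppo_networks") produces three parts and is rejected.
    if len(parts) == 2 and parts[0] and parts[1].isdigit():
        return int(parts[1])
    raise ValueError(f"Unrecognized name format: {name}.")


def parse_step_rsplit(name: str) -> int:
    """Hypothetical alternative: split once from the right."""
    if name.isdigit():
        return int(name)
    # rpartition keeps everything before the last "_" as the prefix.
    prefix, sep, step = name.rpartition("_")
    if sep and prefix and step.isdigit():
        return int(step)
    raise ValueError(f"Unrecognized name format: {name}.")
```

Under these assumptions, `parse_step_strict("ppo_networks_1024000")` raises the same kind of `ValueError` as in the traceback, while `parse_step_rsplit("ppo_networks_1024000")` returns `1024000`.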

scott-yj-yang, Jan 14 '25 16:01

Thank you for reporting this issue. We are working on the fix.

As a workaround, I hope you are fine with renaming the prefix to something like pponetworks?
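Until a fix lands, a small guard on the caller's side can catch an unsupported prefix before the manager is created. The sketch below is an assumption for illustration only: `check_step_prefix` is a hypothetical helper, not part of Orbax, and it simply rejects any prefix containing `_` while suggesting the underscore-free form.

```python
def check_step_prefix(step_prefix: str) -> str:
    """Hypothetical guard: reject prefixes containing '_' and suggest a fix."""
    if "_" in step_prefix:
        suggestion = step_prefix.replace("_", "")
        raise ValueError(
            f"step_prefix {step_prefix!r} contains '_'; "
            f"try {suggestion!r} instead."
        )
    return step_prefix


# Intended use (requires orbax, so shown as a comment only):
# options = ocp.CheckpointManagerOptions(
#     step_prefix=check_step_prefix("pponetworks"))
```

Failing early this way gives a clearer message than the "Unrecognized name format" error raised later while the manager scans the directory.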

niketkumar, Feb 06 '25 19:02

Thank you for your reply. Yes, I am currently naming it PPONetworks to work around this.

scott-yj-yang, Feb 06 '25 19:02