Hi, I am trying to save checkpoints using the following code:
```python
import os

import orbax.checkpoint as ocp

# Inside a method of my training class:
options = ocp.CheckpointManagerOptions(
    max_to_keep=self.max_checkpoints,
    create=True,
    best_fn=best_loss,
    best_mode="min",
)
self.checkpoint_manager = ocp.CheckpointManager(
    os.path.join(self._out_dir, "checkpoints"),
    options=options,
    item_names=("state", "metadata"),
    item_handlers={
        "state": ocp.StandardCheckpointHandler(),
        "metadata": ocp.JsonCheckpointHandler(),
    },
)
```
The problem happens when the checkpoint directory is created. Orbax checks whether it is running with multiple processes by reading an absl flag, but the flags have not been parsed. The error is:

```
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --experimental_orbax_use_distributed_process_id before flags were parsed.
```
This seems to be triggered by the check on `EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID` in the library's `process_index()` function.
I should note that I am running my code from a child process.
Do you have any suggestions on how to avoid this? I am not sure why the flag is not picked up, since it has a default value. I tried parsing the flags in the main file from the command-line arguments, but then absl requires my own arguments to be defined as flags too, which seems like unnecessary overhead.
Originally posted by @raresdolga in https://github.com/google/orbax/discussions/962
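If upgrading Orbax is not an option, one possible workaround (not suggested in this thread, whose resolution was upgrading) is to parse absl flags with defaults only, early in the child process and before constructing the `CheckpointManager`. The sketch below assumes absl-py is installed; `demo_library_flag`, `read_library_flag`, and `ensure_flags_parsed` are hypothetical names used only for illustration.

```python
import sys

from absl import flags

# Hypothetical flag standing in for a flag defined inside a library, such as
# Orbax's --experimental_orbax_use_distributed_process_id.
_LIBRARY_FLAG = flags.DEFINE_bool(
    "demo_library_flag", False, "Hypothetical flag defined by a library."
)


def read_library_flag() -> bool:
    # Raises flags.UnparsedFlagAccessError if absl flags were never parsed,
    # even though the flag has a default value.
    return _LIBRARY_FLAG.value


def ensure_flags_parsed() -> None:
    """Parse absl flags using only the program name, if not parsed yet.

    No other arguments are passed, so every flag keeps its default and the
    script's own command-line arguments are never interpreted as absl flags.
    """
    if not flags.FLAGS.is_parsed():
        # argv[0] is treated as the program name; fall back to "" if argv
        # is empty (e.g. in some embedded interpreters).
        flags.FLAGS(sys.argv[:1] or [""])
```

Calling `ensure_flags_parsed()` at the top of the child process's entry function means later flag reads return defaults instead of raising. `flags.FLAGS.mark_as_parsed()` is a terser alternative that skips flag validation.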
If I open `orbax/checkpoint/multihost/utils.py` and comment out the following check, it works.

Original:
```python
def process_index() -> int:
    if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
        logging.info('Using distributed process id.')
        return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
    else:
        return jax.process_index()
```
New:
```python
def process_index() -> int:
    # if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
    #     logging.info('Using distributed process id.')
    #     return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
    # else:
    return jax.process_index()
```
You must have an older version of Orbax, because the `process_index` implementation no longer looks like that. The latest version wraps the flag read in a try/except so that a problem reading the flag cannot break anyone. Can you try updating?
```python
try:
    experimental_orbax_use_distributed_process_id = (
        EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value
    )
except Exception:  # pylint: disable=broad-exception-caught
    logging.log_first_n(
        logging.INFO,
        '[thread=%s] Failed to get flag value for'
        ' EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.',
        1,
        threading.current_thread().name,
    )
    experimental_orbax_use_distributed_process_id = False
```
Hi,
Indeed, after updating it works now.
Closing this issue. Thank you!