Struggling to restore metadata on other device
Hello,
I am trying to load metadata from a checkpoint on a new device via the CheckpointManager API, but am struggling to find a solution. Below is a minimal example of what I am trying to do.
First I do "training" on GPU by running:
from orbax import checkpoint as ocp
import pathlib
import jax.numpy as jnp

ckpt_dir = pathlib.Path('.').expanduser().absolute()
ckpt_mngr = ocp.CheckpointManager(
    ocp.test_utils.create_empty(ckpt_dir / 'checkpoints'),
    item_names=('params',),
)

params = {'a': jnp.array([1.])}
for i in jnp.arange(10):
    ckpt_mngr.save(
        i,
        args=ocp.args.Composite(params=ocp.args.StandardSave(params)),
    )
I then copy the checkpoint to my local machine, which only has a CPU available. When I try to get the metadata, I see the following behaviour.
# Load with the old API
from orbax import checkpoint as ocp
import pathlib
import jax.numpy as jnp

ckpt_dir = pathlib.Path('.').expanduser().absolute()
ckpt_load = ocp.CheckpointManager(
    ckpt_dir / 'checkpoints',
    {'params': ocp.PyTreeCheckpointer()},
)

latest_step = ckpt_load.latest_step()
ckpt_load.item_metadata(0)
This gives me:
File ~/Documents/venvs/mlff/lib/python3.9/site-packages/orbax/checkpoint/type_handlers.py:164, in _deserialize_sharding_from_json_string(sharding_string)
159 if device := _deserialize_sharding_from_json_string.device_map.get(
160 device_str, None
161 ):
162 return SingleDeviceSharding(device)
--> 164 raise ValueError(
165 f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
166 f' Device={device_str} was not found in jax.local_devices().'
167 )
169 else:
170 raise NotImplementedError(
171 'Sharding types other than `jax.sharding.NamedSharding` have not been '
172 'implemented.'
173 )
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
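On the CPU-only machine, jax.local_devices() of course contains no cuda device, which is exactly what triggers this error; a quick check (plain jax, nothing orbax-specific):

import jax

# On a machine without a visible GPU, only CPU devices are reported, so a
# sharding that references cuda:0 cannot be rebuilt from the saved metadata.
print(jax.local_devices())  # e.g. [CpuDevice(id=0)]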
Does that mean that calls to metadata only work as long as I am on the same device? How else could I get the pytree structure without calling model.init itself? Following this issue https://github.com/google/orbax/issues/648, I could delete the _sharding file and then restore the metadata by setting restore_kwargs appropriately. However, this only works with the old API and seems a bit hacky to me, so I feel I am doing something wrong here.
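For reference, the workaround looks roughly like this with the old API (a sketch only: the on-disk path of the _sharding file, the item_metadata return type, and the exact restore_kwargs may differ between orbax versions):

import numpy as np
import jax
from orbax import checkpoint as ocp

# Remove the device-specific sharding record for the 'params' item
# (path layout is an assumption based on how the checkpoint looks on disk).
sharding_file = ckpt_dir / 'checkpoints' / str(latest_step) / 'params' / '_sharding'
if sharding_file.exists():
    sharding_file.unlink()

# With the sharding file gone, metadata loads again; request plain numpy
# arrays so that no cuda device has to be resolved on the CPU machine.
meta = ckpt_load.item_metadata(latest_step)
restore_args = jax.tree_util.tree_map(
    lambda _: ocp.RestoreArgs(restore_type=np.ndarray), meta['params']
)
restored = ckpt_load.restore(
    latest_step,
    restore_kwargs={'params': {'restore_args': restore_args}},
)

Using the new API instead: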
# Load with the new API
ckpt_dir = pathlib.Path('.').expanduser().absolute()
ckpt_load = ocp.CheckpointManager(
    ckpt_dir / 'checkpoints',
    item_names=('params',),
)

latest_step = ckpt_load.latest_step()
ckpt_load.item_metadata(0)
I get CompositeArgs({}), so no metadata at all.
We're working on a fix for this; unfortunately, the sharding metadata doesn't yet work well in every case. If you must call metadata, just delete the sharding file and continue using the old API for now.
Hi @cpgaffney1! Are there any updates on this matter?
I found a PR with a fix for metadata reading, but it has not been updated since January 17: https://github.com/google/orbax/pull/671
Thanks, Simon
Hi, apologies for the long delay on this - we concluded that using jax.Sharding directly in the metadata was a bad decision from the start, since it can't always be loaded correctly. We're adding a new representation of the sharding metadata that doesn't try to interact directly with real devices. You can track changes here: https://github.com/google/orbax/blob/main/checkpoint/orbax/checkpoint/sharding_metadata.py (the latest change doesn't have an external pull request yet). I expect this can be fixed this week or next.
cc @liangyaning33 who is working on the implementation.
Hi, sorry about the delay. The issue is now fixed. Can you please try again? Thanks!
Hi, I have run into a similar issue.
I save my checkpoints with metrics, train only on CPU, and then want to load a checkpoint on the same machine. But somehow it looks for a cuda:0 device when reading the metadata. Any help would be greatly appreciated!
Error:
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
Checkpoint Manager creation:
options = CheckpointManagerOptions(
    best_fn=lambda metrics: metrics["metric1"],
    best_mode="min",
    max_to_keep=1,
    save_interval_steps=1,
)
checkpoint_manager = CheckpointManager(
    directory=checkpoint_dir,
    options=options,
)
Checkpoint saving:
for step in tbar:
    train_batch = generate_batch(datamodule, "train")
    valid_batch = generate_batch(datamodule, "valid")
    state_neural_net, current_logs = step_fn(
        state_neural_net, train_batch, valid_batch
    )
    ckpt = state_neural_net
    checkpoint_manager.save(
        step,
        args=StandardSave(ckpt),
        metrics={
            "metric1": float(metric1),
            "metric2": float(metric2),
            "metric3": float(metric3),
        },
    )
checkpoint_manager.wait_until_finished()
And then to load the checkpoint:
# Sets up the checkpoint manager as described above
out_class = cls(
    jobid=jobid,
    logger_path=logger_path,
    config=config,
    datamodule=datamodule,
)
if step is None:
    # Only checks steps with metrics available
    step = out_class.checkpoint_manager.best_step()
out_class.neural_net = out_class.checkpoint_manager.restore(
    step, args=StandardRestore()
)
But I get the following output and error:
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py:951, in CheckpointManager.restore(self, step, items, restore_kwargs, directory, args)
948 args = typing.cast(args_lib.Composite, args)
950 restore_directory = self._get_read_step_directory(step, directory)
--> 951 restored = self._checkpointer.restore(restore_directory, args=args)
952 if self._single_item:
953 return restored[DEFAULT_ITEM_NAME]
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/async_checkpointer.py:338, in AsyncCheckpointer.restore(self, directory, *args, **kwargs)
336 """See superclass documentation."""
337 self.wait_until_finished()
--> 338 return super().restore(directory, *args, **kwargs)
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:168, in Checkpointer.restore(self, directory, *args, **kwargs)
166 logging.info('Restoring item from %s.', directory)
167 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 168 restored = self._handler.restore(directory, args=ckpt_args)
169 logging.info('Finished restoring checkpoint from %s.', directory)
170 return restored
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:464, in CompositeCheckpointHandler.restore(self, directory, args)
462 continue
463 handler = self._get_or_set_handler(item_name, arg)
--> 464 restored[item_name] = handler.restore(
465 self._get_item_directory(directory, item_name), args=arg
466 )
467 return CompositeResults(**restored)
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/standard_checkpoint_handler.py:166, in StandardCheckpointHandler.restore(self, directory, item, args)
163 restore_args = checkpoint_utils.construct_restore_args(args.item)
164 else:
165 restore_args = checkpoint_utils.construct_restore_args(
--> 166 self.metadata(directory)
167 )
168 return super().restore(
169 directory,
170 args=pytree_checkpoint_handler.PyTreeRestoreArgs(
171 item=args.item, restore_args=restore_args
172 ),
173 )
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1453, in PyTreeCheckpointHandler.metadata(self, directory)
1427 """Returns tree metadata.
1428
1429 The result will be a PyTree matching the structure of the saved checkpoint.
(...)
1450 tree containing metadata.
1451 """
1452 try:
-> 1453 return self._get_user_metadata(directory)
1454 except FileNotFoundError as e:
1455 raise FileNotFoundError('Could not locate metadata file.') from e
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1418, in PyTreeCheckpointHandler._get_user_metadata(self, directory)
1415 async def _get_metadata():
1416 return await asyncio.gather(*metadata_ops)
-> 1418 batched_metadatas = asyncio.run(_get_metadata())
1419 for keypath_batch, metadata_batch in zip(
1420 batched_keypaths.values(), batched_metadatas
1421 ):
1422 for keypath, value in zip(keypath_batch, metadata_batch):
File /.conda/envs/condreq/lib/python3.10/site-packages/nest_asyncio.py:30, in _patch_asyncio.<locals>.run(main, debug)
28 task = asyncio.ensure_future(main)
29 try:
---> 30 return loop.run_until_complete(task)
31 finally:
32 if not task.done():
File /.conda/envs/condreq/lib/python3.10/site-packages/nest_asyncio.py:98, in _patch_loop.<locals>.run_until_complete(self, future)
95 if not f.done():
96 raise RuntimeError(
97 'Event loop stopped before Future completed.')
---> 98 return f.result()
File /.conda/envs/condreq/lib/python3.10/asyncio/futures.py:201, in Future.result(self)
199 self.__log_traceback = False
200 if self._exception is not None:
--> 201 raise self._exception.with_traceback(self._exception_tb)
202 return self._result
File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:234, in Task.__step(***failed resolving arguments***)
232 result = coro.send(None)
233 else:
--> 234 result = coro.throw(exc)
235 except StopIteration as exc:
236 if self._must_cancel:
237 # Task is cancelled right before coro stops.
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1416, in PyTreeCheckpointHandler._get_user_metadata.<locals>._get_metadata()
1415 async def _get_metadata():
-> 1416 return await asyncio.gather(*metadata_ops)
File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:304, in Task.__wakeup(self, future)
302 def __wakeup(self, future):
303 try:
--> 304 future.result()
305 except BaseException as exc:
306 # This may also be a cancellation.
307 self.__step(exc)
File /.conda/envs/condreq/lib/python3.10/asyncio/tasks.py:232, in Task.__step(***failed resolving arguments***)
228 try:
229 if exc is None:
230 # We use the `send` method directly, because coroutines
231 # don't have `__iter__` and `__next__` methods.
--> 232 result = coro.send(None)
233 else:
234 result = coro.throw(exc)
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:1403, in ArrayHandler.metadata(self, infos)
1401 shardings.append(None)
1402 continue
-> 1403 deserialized = _deserialize_sharding_from_json_string(
1404 sharding_string.item()
1405 )
1406 shardings.append(deserialized or None)
1407 else:
File /.conda/envs/condreq/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:166, in _deserialize_sharding_from_json_string(sharding_string)
161 if device := _deserialize_sharding_from_json_string.device_map.get(
162 device_str, None
163 ):
164 return SingleDeviceSharding(device)
--> 166 raise ValueError(
167 f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
168 f' Device={device_str} was not found in jax.local_devices().'
169 )
171 else:
172 raise NotImplementedError(
173 'Sharding types other than `jax.sharding.NamedSharding` have not been '
174 'implemented.'
175 )
ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().
Versions: flax==0.7.4 jax==0.4.20 jaxlib==0.4.20+cuda12.cudnn89 optax==0.1.9 orbax-checkpoint==0.5.7
UPDATE - SOLVED: For me it was solved by downgrading nvidia-cudnn-cu12 from 9.1.0.70 to a version matching jaxlib-0.4.20+cuda12.cudnn89, i.e. pip install nvidia-cudnn-cu12==8.9.7.29.
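A quick way to confirm which backend and devices jax actually picks up after pinning cudnn (plain jax sanity check, nothing orbax-specific):

import jax

# If cudnn and jaxlib are mismatched, jax can silently fall back to CPU, in
# which case checkpoints written with cuda:0 shardings can no longer be
# resolved against jax.local_devices().
print(jax.default_backend())  # expect 'gpu' on the training machine
print(jax.local_devices())    # e.g. [cuda(id=0)]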
Also note: prefer to specify the shardings for your tree in args=StandardRestore() whenever possible. Either that, or specify the restore_type as np.ndarray. See https://orbax.readthedocs.io/en/latest/checkpointing_pytrees.html
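For example, a restore target with explicit CPU shardings could look roughly like this (a sketch; state_template is a hypothetical pytree with the expected structure, shapes, and dtypes):

import jax
import numpy as np
from orbax import checkpoint as ocp

# Hypothetical template matching the structure/shapes/dtypes of the saved state.
state_template = {"params": {"w": np.zeros((4, 4), dtype=np.float32)}}

# Build an abstract target whose leaves carry an explicit CPU sharding, so the
# restored arrays are placed locally instead of on the device recorded at save time.
cpu_sharding = jax.sharding.SingleDeviceSharding(jax.devices("cpu")[0])
abstract_state = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(np.shape(x), x.dtype, sharding=cpu_sharding),
    state_template,
)

restored = checkpoint_manager.restore(
    step, args=ocp.args.StandardRestore(abstract_state)
)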