
How to restore on a CPU a checkpoint saved on a GPU?

Open paulbarbier opened this issue 2 years ago • 1 comments

Hi,

I saved checkpoints of my weights and some JAX arrays on a GPU instance (V100) using the following snippet:

import orbax.checkpoint as ocp

checkpointer = ocp.PyTreeCheckpointer()
checkpointer.save(path, checkpoint)

Now I want to make plots of this data on my local machine using:

import orbax.checkpoint as ocp
checkpointer = ocp.StandardCheckpointHandler()

checkpoint = checkpointer.restore(path, item=None)

and I got the following error:

File /opt/homebrew/Caskroom/miniforge/base/envs/pgm/lib/python3.9/site-packages/orbax/checkpoint/type_handlers.py:162, in _deserialize_sharding_from_json_string(sharding_string)
    157   if device := _deserialize_sharding_from_json_string.device_map.get(
    158       device_str, None
    159   ):
    160     return SingleDeviceSharding(device)
--> 162   raise ValueError(
    163       f'{ShardingTypes.SINGLE_DEVICE_SHARDING.value} with'
    164       f' Device={device_str} was not found in jax.local_devices().'
    165   )
    167 else:
    168   raise NotImplementedError(
    169       'Sharding types other than `jax.sharding.NamedSharding` have not been '
    170       'implemented.'
    171   )

ValueError: SingleDeviceSharding with Device=cuda:0 was not found in jax.local_devices().

My workaround for now is to replace every occurrence of cuda:0 in _sharding with the output of str(jax.local_devices()[0]), but that could get tricky in more complex situations.

In PyTorch you can specify map_location in the torch.load function to address this kind of issue.

How would you handle this?

Thanks for your feedback.

paulbarbier avatar Dec 22 '23 00:12 paulbarbier

You're restoring with a different jax.sharding.Sharding than the one the checkpoint was saved with. For the item argument of restore, you need to pass a tree matching the checkpoint tree. The values of the tree should be array-like objects matching the properties of the checkpoint (shape and dtype), with the .sharding property set to a sharding that exists on the machine you are restoring on. Typically the tree values you specify should be jax.ShapeDtypeStruct or jax.Array. Also check out the metadata method to get properties of the checkpoint tree.

cpgaffney1 avatar Jan 02 '24 16:01 cpgaffney1
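
A minimal sketch of the approach described in the reply above, under some assumptions. The helper names with_local_sharding and restore_on_local_device are hypothetical, not part of Orbax. The shapes and dtypes of the checkpoint tree are taken as given here (they may be obtainable from the metadata method, whose return structure varies between Orbax versions). The restore call assumes an orbax-checkpoint release that provides ocp.Checkpointer, ocp.StandardCheckpointHandler and ocp.args.StandardRestore; older releases may need the item and restore_args arguments instead.

```python
import jax
import jax.numpy as jnp
from jax.sharding import SingleDeviceSharding


def with_local_sharding(abstract_tree, device=None):
    """Rebuild a tree of shape/dtype leaves with a sharding on a local device.

    Hypothetical helper: each leaf only needs .shape and .dtype attributes.
    """
    if device is None:
        # On a CPU-only machine this is the CPU device.
        device = jax.local_devices()[0]
    sharding = SingleDeviceSharding(device)
    return jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding),
        abstract_tree,
    )


def restore_on_local_device(path, abstract_tree):
    """Restore a checkpoint onto the first local device (assumed Orbax API)."""
    # Imported lazily so the helper above can be used without Orbax installed.
    import orbax.checkpoint as ocp

    target = with_local_sharding(abstract_tree)
    # Assumption: this args-based API is available in the installed version.
    checkpointer = ocp.Checkpointer(ocp.StandardCheckpointHandler())
    return checkpointer.restore(path, args=ocp.args.StandardRestore(target))


# Example target for a checkpoint holding two arrays (shapes are made up).
abstract_tree = {
    "w": jax.ShapeDtypeStruct((2, 3), jnp.float32),
    "b": jax.ShapeDtypeStruct((3,), jnp.float32),
}
target = with_local_sharding(abstract_tree)
```

With a target like this, restore_on_local_device(path, abstract_tree) would ask Orbax to place every array on the local device instead of looking up the saved cuda:0 device, which plays roughly the role that map_location plays in torch.load.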