Cannot restore sharded array on different machine
Hello,
This is somewhat similar to #646. During training, I saved my parameters in a sharded manner (could not use aggregate because they were sharded over multiple hosts). Now I can't restore the parameters on a new machine with a different number of devices. Even restoring the parameters into CPU memory as NumPy arrays would be fine, but no matter what I do I get this error:
File ~/miniconda3/envs/test/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py:136, in _deserialize_sharding_from_json_string(sharding_string)
131 axis_names = list(deserialized_dict[_MESH_AXES])
132 partition_spec = tuple(deserialized_dict[_PARTITION_SPEC])
134 sharding = NamedSharding(
135 jax.sharding.Mesh(
--> 136 np.array(jax.devices()).reshape(shape), axis_names=axis_names
137 ),
138 jax.sharding.PartitionSpec(*partition_spec),
139 )
140 return sharding
142 elif (
143 deserialized_dict[_SHARDING_TYPE]
144 == ShardingTypes.SINGLE_DEVICE_SHARDING.value
(...)
147 # their str representation.
148 # Cache tip: See Function Attributes https://peps.python.org/pep-0232/.
ValueError: cannot reshape array of size 2 into shape (1,64)
This happens even upon trying to get the checkpoint structure using PyTreeCheckpointer.metadata(). As a result, I can't specify RestoreArgs to prevent the parameters from being loaded as sharded jax.Arrays.
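To make the failure concrete, here is a minimal NumPy-only sketch of what the reshape in the traceback is doing - the (1, 64) mesh shape and the 2 local devices are the numbers from the error above:

import numpy as np

# The _sharding metadata records the mesh shape of the saving run, e.g. (1, 64).
saved_mesh_shape = (1, 64)

# On the restoring machine only 2 devices are visible; this stands in for
# np.array(jax.devices()) inside _deserialize_sharding_from_json_string.
local_devices = np.empty(2, dtype=object)

local_devices.reshape(saved_mesh_shape)
# ValueError: cannot reshape array of size 2 into shape (1,64)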
See my other answer - I think it's the same issue of not specifying the sharding during restore. Please reopen if this is not a duplicate.
As I mentioned at the end, this happens even when calling metadata(), so I can't get the tree structure of the checkpoint in order to specify a sharding. It would also be nice if a more informative error were raised, and/or if there were an option to simply restore the checkpoint into CPU memory (without having to know the tree structure). If the mesh on the restoring machine is incompatible with the saved sharding, it would even make sense to me to fall back to a reasonable default (e.g., loading into CPU or device-0 memory).
@cpgaffney1 I can't re-open the issue, but this seems like a concrete bug that is separate from #646.
Hmm, I see the problem now - there does seem to be a bug in the logic that restores the sharding from the metadata: it assumes the local devices can be reshaped into the mesh shape recorded in the checkpoint, which is not true in general. Thanks for spotting that.
I'll look into a fix. In the meantime, what you can do is specify RestoreArgs with restore_type=np.ndarray. You do need to know the tree structure, because restore_type is specified individually for each leaf. This design is essentially settled; I don't think we'll provide a fallback, since constructing a tree of RestoreArgs is very easy with tree_map. The fact that the metadata function isn't working makes this trickier, but I think it will not take long for us to fix.
Thanks. I was able to work around it with this:
import os
import jax
import numpy as np
import orbax.checkpoint as ocp

# Delete the saved sharding metadata so restore doesn't try to rebuild the
# original mesh on this machine.
if os.path.exists(f"{local_path}/{step}/_sharding"):
    os.remove(f"{local_path}/{step}/_sharding")

manager = ocp.CheckpointManager(local_path, ocp.PyTreeCheckpointer())
structure = manager.item_metadata(step)

# Restore every leaf as a plain np.ndarray in host memory.
params = manager.restore(
    step,
    restore_kwargs={
        "restore_args": jax.tree_map(
            lambda _: ocp.RestoreArgs(restore_type=np.ndarray), structure
        )
    },
)
I do think this is quite verbose, even without the first two lines, but if the design is settled I understand.
I agree it's a bit verbose with PyTreeCheckpointHandler, which is designed for maximum flexibility. You can try StandardCheckpointHandler instead - e.g.:
manager = ocp.CheckpointManager(local_path, ocp.StandardCheckpointer())
structure = manager.item_metadata(step)
target = jax.tree_util.tree_map(lambda x: np.zeros(x.shape), structure)
params = manager.restore(step, target)
(This still depends on the sharding file being deleted, of course - unless you patch in the fix mentioned above, in which case it might work regardless.)
That should restore as np.ndarray, since the target values are of type np.ndarray.
Ultimately you need at least one line instructing it what type to restore for each leaf.
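If you eventually want sharded jax.Arrays on the new machine rather than NumPy arrays, a rough (untested) variant of the same idea is to make the target leaves abstract arrays that carry the sharding you want - assuming the metadata leaves expose shape and dtype. The fully replicated sharding below is just a placeholder that fits any device count; in practice you'd build the per-leaf shardings you actually need:

import jax
import numpy as np

# A 1-D mesh over whatever devices this machine has; an empty PartitionSpec
# replicates each leaf, so it fits regardless of the new device count.
mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=("devices",))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

structure = manager.item_metadata(step)
target = jax.tree_util.tree_map(
    lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=replicated),
    structure,
)
params = manager.restore(step, target)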
Ah I see, I wasn't aware of StandardCheckpointer. I don't think it's covered in the documentation. What are the concrete differences between StandardCheckpointer and PyTreeCheckpointer?
Documentation is just now being updated; there will be a substantial amount of new docs!
StandardCheckpointHandler is really just a wrapper around PyTreeCheckpointHandler that allows you to avoid having to deal with RestoreArgs. It enforces certain standard leaf types.
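Roughly, the difference in usage looks like this (a loose sketch - exact signatures vary between orbax versions, and ckpt_path / structure here are placeholders for the checkpoint directory and its tree structure):

import jax
import numpy as np
import orbax.checkpoint as ocp

# PyTreeCheckpointer: you control the restore type (and sharding) of every
# leaf yourself via a tree of RestoreArgs.
params = ocp.PyTreeCheckpointer().restore(
    ckpt_path,
    restore_args=jax.tree_util.tree_map(
        lambda _: ocp.RestoreArgs(restore_type=np.ndarray), structure
    ),
)

# StandardCheckpointer: no RestoreArgs; it infers what to restore from a
# target tree with standard leaf types (np.ndarray, jax.Array, ...).
params = ocp.StandardCheckpointer().restore(
    ckpt_path,
    jax.tree_util.tree_map(lambda x: np.zeros(x.shape), structure),
)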