saving with default/global mesh is broken
On orbax `main` with jax 0.7/0.8, if a global mesh is set with `jax.sharding.set_mesh`, it is impossible to save a checkpoint because of the error shown below.
Am I doing something wrong, or is this an issue in orbax?
```python
import numpy as np
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(),
)

# Commenting out the line below makes ckptr.save work
jax.sharding.set_mesh(sharding.mesh)

create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree.map(create_sharded_array, state)
abstract_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, state)

path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')

# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())
ckptr.save(path / '1', args=ocp.args.StandardSave(state))
```
Error:

```
File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:262, in _serialize_arrays(arrays, infos, args, dispatcher, replica_id, use_replica_parallel, min_slice_bytes_for_replica_parallel, max_replicas_for_replica_parallel, primary_host, metadata_key, array_metadata_store, enable_replica_parallel_separate_folder, ext_metadata)
    259 """D2H transfer and serialize arrays using dispatcher if provided."""
    260 if dispatcher is None:
    261   # Complete D2H transfer in parallel for each array.
--> 262   values_on_host = replica_slices.transfer_arrays_to_host(
    263       arrays,
    264       replica_id,
    265       use_replica_parallel,
    266       enable_pinned_host_transfer=infos[0].enable_pinned_host_transfer,
    267       min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel,
    268       max_replicas_for_replica_parallel=max_replicas_for_replica_parallel,
    269   )
    270 return future.CommitFutureAwaitingContractedSignals(
    271     _async_serialize_replica_slices(
    272         values_on_host,
(...)
    282     name='array_type_handler',
    283 )
    284 else:
....

File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/jax/_src/pjit.py:159, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    156   fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
    157   msg = stages._device_assignment_mismatch_error(
    158       fun_name, fails, args_flat, 'jit', p.arg_names)
--> 159   raise ValueError(msg) from None
    160 except dtypes.InvalidInputException as e:
    161   arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names

ValueError: Received incompatible devices for jitted computation. Got argument args[0] of slice with shape int32[16] and device ids [0] on platform CPU and jit's context mesh with device ids [0, 1, 2, 3] on platform CPU
```
It seems that if a global mesh is set, I need to "unset" it to make `save` work:

```python
null_mesh = jax.make_mesh((), ())
with jax.sharding.set_mesh(null_mesh):
    ckptr.save(path / '1', args=ocp.args.StandardSave(state))
    ckptr.wait_until_finished()
```

but this is not documented anywhere...
Thanks for raising this - Orbax extracts the mesh info from arrays themselves, which is a little more flexible since you're not limited to a global mesh for all arrays.
I think the global mesh setting should just be ignored in this case as it is redundant info on top of what the arrays already contain.
The issue may be that the D2H transfer runs into trouble because we need to specify CPU devices as the intermediate location of the arrays before saving, and that conflicts with the global mesh setting, which uses accelerators.
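To illustrate the point about redundancy (a standalone sketch, not orbax code): every `jax.Array` already records its own sharding, so the mesh and device set can be recovered from the array itself, independent of any globally set mesh.

```python
import jax
import jax.numpy as jnp

# Place an array with an explicit NamedSharding; the array then carries
# the mesh and device information with it.
mesh = jax.sharding.Mesh(jax.devices(), ('model',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
x = jax.device_put(jnp.arange(16), sharding)

# The sharding attached to the array records the mesh it was placed on,
# so no global mesh is needed to recover this information.
print(x.sharding.mesh.axis_names)                 # ('model',)
print(sorted(d.id for d in x.devices()))          # ids of all devices holding x
```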
We'll get back to you with a fix!
Hi there, thanks again for the report. Quick update: the issue should be fixed by https://github.com/google/orbax/commit/5a0bd479c39b86b31ef855e7770b461a18f3a3c8
As pointed out in the report, when a global mesh is active, `jax.lax.slice_in_dim` can fail with an incompatible-device error if the array being sliced is not on the devices specified by the global mesh.
The fix is to temporarily override the mesh context before calling `jax.lax.slice_in_dim`: a new mesh is constructed from the devices in `ReplicaSlice.unsliced_data.sharding` and activated via `jax.sharding.set_mesh`.
https://github.com/google/orbax/blob/b2c2b1223cb4ede0b6503a205faef143affe4692/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py#L93-L106
If you have a chance to test this against your workload, please let us know if there are any further issues!