
saving with default/global mesh is broken

Open · PhilipVinc opened this issue 1 month ago • 2 comments

On orbax main with jax 0.7/0.8, if a global mesh is set with jax.sharding.set_mesh, it is impossible to save a checkpoint; saving fails with the error shown below.

Am I doing something wrong, or is this an issue in orbax?

import numpy as np
import orbax.checkpoint as ocp
import jax
import jax.numpy as jnp

sharding = jax.sharding.NamedSharding(
    jax.sharding.Mesh(jax.devices(), ('model',)),
    jax.sharding.PartitionSpec(),
)
# Setting the global mesh below is what breaks ckptr.save;
# commenting it out makes saving work.
jax.sharding.set_mesh(sharding.mesh)

create_sharded_array = lambda x: jax.device_put(x, sharding)
state = {
    'a': np.arange(16),
    'b': np.ones(16),
}
state = jax.tree.map(create_sharded_array, state)
abstract_state = jax.tree.map(ocp.utils.to_shape_dtype_struct, state)

path = ocp.test_utils.erase_and_create_empty('/tmp/basic/')
# Make sure to use async for improved performance!
ckptr = ocp.AsyncCheckpointer(ocp.StandardCheckpointHandler())

ckptr.save(path / '1', args=ocp.args.StandardSave(state))
ckptr.wait_until_finished()

The error:

File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/orbax/checkpoint/_src/serialization/jax_array_handlers.py:262, in _serialize_arrays(arrays, infos, args, dispatcher, replica_id, use_replica_parallel, min_slice_bytes_for_replica_parallel, max_replicas_for_replica_parallel, primary_host, metadata_key, array_metadata_store, enable_replica_parallel_separate_folder, ext_metadata)
    259 """D2H transfer and serialize arrays using dispatcher if provided."""
    260 if dispatcher is None:
    261   # Complete D2H transfer in parallel for each array.
--> 262   values_on_host = replica_slices.transfer_arrays_to_host(
    263       arrays,
    264       replica_id,
    265       use_replica_parallel,
    266       enable_pinned_host_transfer=infos[0].enable_pinned_host_transfer,
    267       min_slice_bytes_for_replica_parallel=min_slice_bytes_for_replica_parallel,
    268       max_replicas_for_replica_parallel=max_replicas_for_replica_parallel,
    269   )
    270   return future.CommitFutureAwaitingContractedSignals(
    271       _async_serialize_replica_slices(
    272           values_on_host,
   (...)    282       name='array_type_handler',
    283   )
    284 else:
...
File ~/Nextcloud/Codes/Python/netket_pro/.venv/lib/python3.13/site-packages/jax/_src/pjit.py:159, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    156   fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
    157   msg = stages._device_assignment_mismatch_error(
    158       fun_name, fails, args_flat, 'jit', p.arg_names)
--> 159   raise ValueError(msg) from None
    160 except dtypes.InvalidInputException as e:
    161   arg_names = [''] * len(args_flat) if p.arg_names is None else p.arg_names

ValueError: Received incompatible devices for jitted computation. Got argument args[0] of slice with shape int32[16] and device ids [0] on platform CPU and jit's context mesh with device ids [0, 1, 2, 3] on platform CPU

PhilipVinc · Nov 05 '25, 11:11

It seems that if a global mesh is set, I need to 'unset' it to make 'save' work:

# Workaround: temporarily replace the global mesh for the save.
# jax.make_mesh((), ()) builds a mesh with no named axes over a single
# device, and jax.sharding.set_mesh used as a context manager restores
# the previous mesh on exit.
null_mesh = jax.make_mesh((), ())
with jax.sharding.set_mesh(null_mesh):
    ckptr.save(path / '1', args=ocp.args.StandardSave(state))
    ckptr.wait_until_finished()

but this is not documented anywhere...

PhilipVinc · Nov 05 '25, 11:11

Thanks for raising this. Orbax extracts the mesh info from the arrays themselves, which is a little more flexible since you're not limited to a single global mesh for all arrays.

I think the global mesh setting should simply be ignored in this case, since it is redundant with the information the arrays already carry.

The issue may be that the D2H (device-to-host) transfer runs into trouble: before saving, we need to place the arrays on CPU devices as an intermediate step, and that placement conflicts with the global mesh setting, which targets accelerators.
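
For context, the same conflict can be reproduced outside Orbax. A minimal sketch, assuming several devices are available (illustrative JAX code, not the Orbax internals):

import jax
import jax.numpy as jnp

# Assumes multiple devices; on CPU they can be simulated with
# XLA_FLAGS=--xla_force_host_platform_device_count=4
mesh = jax.sharding.Mesh(jax.devices(), ('model',))
jax.sharding.set_mesh(mesh)

# Commit an array to a single device, so its device assignment no longer
# matches the multi-device context mesh.
x = jax.device_put(jnp.arange(16), jax.devices()[0])

# The jitted slice picks up the global context mesh and rejects the argument:
# ValueError: Received incompatible devices for jitted computation. ...
y = jax.jit(lambda a: a[:8])(x)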

We'll get back to you with a fix!

cpgaffney1 · Nov 05 '25, 16:11

Hi there, thanks again for the report. Quick update: the issue should be fixed by https://github.com/google/orbax/commit/5a0bd479c39b86b31ef855e7770b461a18f3a3c8

As pointed out in the report, when a global mesh is active, jax.lax.slice_in_dim can fail with an incompatible device error if the array being sliced is not on the devices specified by the global mesh.

The fix is to temporarily override the mesh context before calling jax.lax.slice_in_dim. A new mesh is constructed using the devices from ReplicaSlice.unsliced_data.sharding and set via jax.sharding.set_mesh.

https://github.com/google/orbax/blob/b2c2b1223cb4ede0b6503a205faef143affe4692/checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py#L93-L106
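
A simplified sketch of the pattern (the helper name and mesh axis name here are illustrative; the permalink above shows the actual implementation):

import jax

def _slice_on_shard_devices(shard, start, limit, axis):
    # Build a mesh from the devices the shard actually lives on, so the
    # jitted slice sees a context mesh that matches its argument instead
    # of whatever global mesh the caller has set.
    shard_mesh = jax.sharding.Mesh(list(shard.sharding.device_set), ('x',))
    with jax.sharding.set_mesh(shard_mesh):
        return jax.lax.slice_in_dim(shard, start, limit, axis=axis)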

If you have a chance to test this against your workload, please let us know if there are any further issues!

JustinPan-goog · Nov 17 '25, 15:11