
Checkpointing issue in a distributed setup

Open · vladyorsh opened this issue 1 year ago · 1 comment

I'm trying to checkpoint Flax's TrainState in a distributed setup, where each node has access to multiple devices:

import os

import orbax.checkpoint as ocp
from flax.jax_utils import unreplicate


def save_checkpoint(args, state, step):
    # Drop the leading device axis added by pmap (keeps replica 0 of each leaf).
    state = unreplicate(state)
    with ocp.CheckpointManager(
        os.path.abspath(args.checkpointing_path),
        options=ocp.CheckpointManagerOptions(
            save_interval_steps=args.checkpointing_frequency, max_to_keep=10),
        item_handlers={
            'state': ocp.PyTreeCheckpointHandler(write_tree_metadata=True),
            'extra_metadata': ocp.JsonCheckpointHandler(),
        },
    ) as mngr:
        # Save the train state and the run arguments as two items of one checkpoint.
        mngr.save(step, args=ocp.args.Composite(
            state=ocp.args.PyTreeSave(state),
            extra_metadata=ocp.args.JsonSave(vars(args)),
        ))

However, I'm getting the following error:

File "/lnet/troja/work/people/yorsh/custom_modules/transformers/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 1340, in serialize
    raise ValueError(
ValueError: Cannot serialize host local arrays. Arrays like this are typically obtained using pmap. Consider using fully_replicated_host_local_array_to_global_array in orbax/checkpoint/utils.py to convert your arrays into serializable objects.

If I add state = jax.tree_util.tree_map(ocp.utils.fully_replicated_host_local_array_to_global_array, state) as the first line of the function, I get a different error:

".../orbax/checkpoint/utils.py", line 792, in fully_replicated_host_local_array_to_global_array
    raise ValueError('Array must be fully replicated.')

The TrainState is the output of a pmap-ed update step with pmean-ed gradients, so its values should be identical across devices. However, its arrays have the is_fully_replicated flag set to False, so JAX does not treat them as replicated. At the same time, they have is_fully_addressable set to True, which is what makes Orbax raise the first exception. Is there a correct way, or a workaround, to force checkpointing?
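Since the is_fully_replicated flag describes the sharding rather than the values, one way to confirm that the pmean-ed replicas really are identical is to compare each device's copy directly. The sketch below is a hypothetical debugging helper (`replicas_identical` is a made-up name, not an Orbax or Flax API); it assumes every leaf of the pmap output has a leading device axis, pulls all local shards into host memory, and in a multi-process job only sees the devices local to the current process. NaNs compare unequal, so a state containing NaNs will report False.

```python
import jax
import jax.numpy as jnp
import numpy as np


def replicas_identical(tree) -> bool:
    """Hypothetical helper: True if every leaf holds identical copies along axis 0."""
    def leaf_ok(x):
        host = np.asarray(x)  # gathers all local shards into one host array
        return bool(np.all(host == host[0]))  # compare every replica to replica 0
    return all(leaf_ok(x) for x in jax.tree_util.tree_leaves(tree))


n = jax.local_device_count()
# Toy stand-in for a pmap-ed update step: pmean makes every replica equal.
step = jax.pmap(lambda p: jax.lax.pmean(p, "d"), axis_name="d")
params = {"w": jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)}
out = step(params)

leaf = out["w"]
print(type(leaf.sharding).__name__)  # which sharding type the pmap output carries
print(leaf.is_fully_addressable, leaf.is_fully_replicated)
print(replicas_identical(out))
```

If this reports True for the real TrainState, the values are safe to collapse to a single replica even though the sharding metadata says otherwise.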

vladyorsh, Apr 24 '24 20:04

pmap arrays are always difficult to work with. A few things I would try:

  1. Use pjit instead of pmap. Probably not a reasonable suggestion for me to make, but I'll include it anyway.
  2. Use fully_replicated_host_local_array_to_global_array without Flax unreplicate.
  3. Try using jax.jit to reshard from PmapSharding to NamedSharding. NamedSharding is guaranteed to work in all cases, but PmapSharding can be difficult.
  4. Just convert everything to a numpy array. If the arrays are fully replicated across all devices, this should be no problem.
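Option 4 can be sketched as taking replica 0 of every leaf and copying it to host memory. This is a minimal sketch, assuming each leaf carries pmap's leading device axis and that all replicas are identical; `to_host_numpy` is a hypothetical name, and the toy `state` dict stands in for a real TrainState.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_host_numpy(replicated_tree):
    """Hypothetical helper: replica 0 of every leaf, as a host NumPy array.

    Assumes each leaf has a leading device axis (as pmap outputs do) and that
    all replicas hold the same values, e.g. because gradients were pmean-ed.
    """
    return jax.tree_util.tree_map(lambda x: np.asarray(x[0]), replicated_tree)


n = jax.local_device_count()
# Toy stand-in for a replicated TrainState produced by a pmap-ed step.
bump = jax.pmap(lambda p: jax.tree_util.tree_map(lambda x: x + 1, p))
state = bump({"w": jnp.zeros((n, 2, 2)), "step": jnp.zeros((n,), jnp.int32)})

host_state = to_host_numpy(state)
print({k: (type(v).__name__, v.shape) for k, v in host_state.items()})
```

The resulting tree of `numpy.ndarray` leaves could then be passed to `ocp.args.PyTreeSave` in place of the `jax.Array` tree. How Orbax handles NumPy leaves in a multi-process save (for example, which process writes them) is not covered in this thread and should be checked against the documentation for the installed Orbax version.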

cpgaffney1, Apr 25 '24 17:04