Checkpointing issue in a distributed setup
I'm trying to checkpoint Flax's `TrainState` in a distributed setup, where each node has access to multiple devices:
```python
def save_checkpoint(args, state, step):
    state = unreplicate(state)  # flax.jax_utils.unreplicate
    with ocp.CheckpointManager(
        os.path.abspath(args.checkpointing_path),
        options=ocp.CheckpointManagerOptions(
            save_interval_steps=args.checkpointing_frequency, max_to_keep=10
        ),
        item_handlers={
            'state': ocp.PyTreeCheckpointHandler(write_tree_metadata=True),
            'extra_metadata': ocp.JsonCheckpointHandler(),
        },
    ) as mngr:
        mngr.save(step, args=ocp.args.Composite(
            state=ocp.args.PyTreeSave(state),
            extra_metadata=ocp.args.JsonSave(vars(args)),
        ))
```
However, I'm getting the following error:

```
File "/lnet/troja/work/people/yorsh/custom_modules/transformers/lib/python3.10/site-packages/orbax/checkpoint/type_handlers.py", line 1340, in serialize
    raise ValueError(
ValueError: Cannot serialize host local arrays. Arrays like this are typically obtained using pmap. Consider using fully_replicated_host_local_array_to_global_array in orbax/checkpoint/utils.py to convert your arrays into serializable objects.
```
If I add `state = jax.tree_util.tree_map(ocp.utils.fully_replicated_host_local_array_to_global_array, state)` as the first line of the function, I get another one:

```
".../orbax/checkpoint/utils.py", line 792, in fully_replicated_host_local_array_to_global_array
    raise ValueError('Array must be fully replicated.')
```
The obtained `TrainState` is the result of a pmap-ed update with pmean-ed gradients, so its values should be identical across devices. However, its tensors have the `is_fully_replicated` flag set to `False`, so technically they are not considered replicated. At the same time, the `is_fully_addressable` flag is set to `True`, which leads to Orbax raising an exception. Is there any correct way, or a hack, to force checkpointing?
pmap arrays are always difficult to work with. A few things I would try:
- Use `pjit` instead of `pmap`. Probably not a reasonable suggestion for me to make, but I'll include it anyway.
- Use `fully_replicated_host_local_array_to_global_array` without Flax's `unreplicate`.
- Try using `jax.jit` to reshard from `PmapSharding` to `NamedSharding`. `NamedSharding` is guaranteed to work in all cases, but `PmapSharding` can be difficult.
- Just convert everything to a numpy array. If the arrays are fully replicated across all devices, this should be no problem.
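The last suggestion can be sketched as a small helper. The name `pmap_state_to_numpy` is my own, not an Orbax API, and this assumes the state has already been reduced to a single copy (e.g. via `unreplicate`):

```python
import jax
import numpy as np


def pmap_state_to_numpy(state):
    """Convert every leaf of a pytree to a plain numpy array.

    jax.device_get copies each array to the host regardless of its
    sharding, so the PmapSharding metadata that trips Orbax's
    is_fully_replicated check is discarded entirely.
    """
    return jax.tree_util.tree_map(
        lambda x: np.asarray(jax.device_get(x)), state
    )
```

Passing the resulting numpy pytree to `ocp.args.PyTreeSave` sidesteps the host-local-array check, at the cost of an extra host copy and losing any sharding metadata in the checkpoint.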