
Warnings while implementing checkpoint operations for SPMD

Open mfatih7 opened this issue 1 year ago • 3 comments

Hello

While implementing checkpoint operations for SPMD by following the documentation, I get a couple of warnings:

/home/THE_USER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_saver.py:29: UserWarning:
'save_state_dict' is deprecated and will be removed in future versions.Please use 'save' instead.
warnings.warn(
/home/THE_USER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/filesystem.py:92: UserWarning:
TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class.
This should only matter to you if you are
using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
/home/THE_USER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py:25: UserWarning:
'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead.
warnings.warn(

I think the documentation should be updated accordingly.

I use the checkpoint save and load operations, but I have no idea why I get the TypedStorage/UntypedStorage warning.

Is there any way to use the xm.save and xser.load functions for SPMD implementations?

If checkpointing via the distributed functions is mandatory, maybe it would be useful to direct people to the distributed functions for checkpointing even for single-core and multi-core training that does not involve FSDP.

mfatih7 avatar Feb 06 '24 12:02 mfatih7

For SPMD operation, I understand that the better way to checkpoint is to use CheckpointManager, since it enables saving multiple checkpoints, one after each epoch. But because the deprecated functions are also used inside CheckpointManager, the warnings persist even when I use CheckpointManager.
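For context, here is a minimal sketch of that CheckpointManager pattern, assuming torch_xla's experimental distributed_checkpoint module; the path, retention settings, epoch count, and the stand-in model are illustrative placeholders, not from this thread:

```python
import torch.nn as nn
import torch_xla.experimental.distributed_checkpoint as xc

model = nn.Linear(8, 8)  # stand-in for the real SPMD-sharded model

# Path and retention settings are illustrative placeholders.
chkpt_mgr = xc.CheckpointManager(
    path='/tmp/checkpoints',  # base directory for all checkpoints
    save_interval=1,          # eligible to checkpoint on every step (epoch here)
    max_to_keep=5,            # retain only the 5 most recent checkpoints
)

for epoch in range(10):
    # ... training for one epoch would happen here ...
    state_dict = {'model': model.state_dict()}
    # Tags the checkpoint with the step number and prunes older ones.
    chkpt_mgr.save(epoch, state_dict)

# To resume: restore the newest tracked step into a same-shaped state_dict.
last_step = max(chkpt_mgr.all_steps())
state_dict = {'model': model.state_dict()}
chkpt_mgr.restore(last_step, state_dict)
model.load_state_dict(state_dict['model'])
```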

mfatih7 avatar Feb 07 '24 17:02 mfatih7

These warnings come from the upstream distributed checkpointing library and are OK to ignore for now. The deprecations will be addressed before the 2.3 release (the save_state_dict and load_state_dict functions were renamed).

For SPMD, the recommendation is to use the torch.distributed.checkpoint APIs, either directly or through CheckpointManager. Using xm.save or directly storing the state_dict will require all shards to be gathered onto the host writing the checkpoint, which is an expensive operation for large models.
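For illustration, a minimal sketch of that recommendation using the renamed upstream entry points together with torch_xla's SPMD planners; the checkpoint directory and the stand-in model are placeholders, and save/load are the replacements for the deprecated save_state_dict/load_state_dict:

```python
import torch.nn as nn
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

model = nn.Linear(8, 8)                  # stand-in for the real SPMD-sharded model
CHECKPOINT_DIR = '/tmp/spmd_checkpoint'  # illustrative path

# Save: 'save' replaces the deprecated 'save_state_dict'. With the SPMD
# planner, each host writes only the shards it owns, avoiding a full gather.
state_dict = {'model': model.state_dict()}
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)

# Load: 'load' replaces the deprecated 'load_state_dict'. The state_dict
# is populated in place, shard by shard.
state_dict = {'model': model.state_dict()}
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict['model'])
```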

EDIT: Thanks for pointing out the outdated documentation! I'll also update it to use the new upstream APIs.

jonb377 avatar Feb 07 '24 22:02 jonb377

Thanks, @jonb377, is it ok to assign this ticket to you now?

ManfeiBai avatar Feb 09 '24 23:02 ManfeiBai