Warnings while implementing checkpoint operations for SPMD
Hello
While implementing checkpoint operations for SPMD by following the documentation, I get a couple of warnings:
/home/THE_USER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_saver.py:29: UserWarning: 'save_state_dict' is deprecated and will be removed in future versions. Please use 'save' instead.
  warnings.warn(
/home/THE_USER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/filesystem.py:92: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
/home/THE_USER/env3_8/lib/python3.8/site-packages/torch/distributed/checkpoint/state_dict_loader.py:25: UserWarning: 'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead.
  warnings.warn(
I think the documentation must be updated.
I use the checkpoint save and load operations, but I have no idea why I get the warning about UntypedStorage.
Is there any way to use the xm.save and xser.load functions for SPMD implementations?
If checkpointing via the distributed functions is mandatory, it might be useful to direct people to use them for checkpointing even in single-core and multi-core training that does not involve FSDP.
For SPMD, I understand that the preferred way to checkpoint is CheckpointManager, because it makes it easy to save a checkpoint after each epoch. However, since CheckpointManager also calls the deprecated functions internally, the warnings persist even when I use it.
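For reference, a minimal sketch of how CheckpointManager can be driven; the checkpoint directory, save interval, retention count, and the toy model below are placeholders rather than the actual training setup:

```python
import torch
import torch_xla.experimental.distributed_checkpoint as xc

# Placeholder values: adjust path, interval, and retention to your setup.
chkpt_mgr = xc.CheckpointManager(
    path='/tmp/spmd_ckpts',  # hypothetical checkpoint directory
    save_interval=1,         # consider saving at every tracked step
    max_to_keep=3,           # keep only the most recent checkpoints
)

model = torch.nn.Linear(128, 128)  # stand-in for the real sharded SPMD model

for epoch in range(10):
    # ... one epoch of training ...
    state_dict = {'model': model.state_dict()}
    # save() returns True when a checkpoint was actually taken for this step.
    chkpt_mgr.save(epoch, state_dict)

# On restart, restore the latest tracked checkpoint in place.
tracked = chkpt_mgr.all_steps()
if tracked:
    state_dict = {'model': model.state_dict()}
    chkpt_mgr.restore(max(tracked), state_dict)
    model.load_state_dict(state_dict['model'])
```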
These warnings come from the upstream distributed checkpointing library and are OK to ignore for now. The deprecations will be addressed before the 2.3 release (the save_state_dict and load_state_dict functions were renamed).
For SPMD, the recommendation is to use the torch.distributed.checkpoint APIs, either directly or through CheckpointManager. Using xm.save or directly storing the state_dict would require all shards to be gathered onto the host writing the checkpoint, which is an expensive operation for large models.
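A minimal sketch of the direct usage, assuming the SPMD planners exported from torch_xla.experimental.distributed_checkpoint, a placeholder checkpoint directory, and that `model` is an already-constructed sharded SPMD model:

```python
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

CKPT_DIR = '/tmp/spmd_ckpt'  # placeholder path

# Save: each host writes only the shards it owns, so nothing is gathered
# onto a single host.
state_dict = {'model': model.state_dict()}
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CKPT_DIR),
    planner=xc.SPMDSavePlanner(),
)

# Load: restore in place into the sharded model, then apply the result.
state_dict = {'model': model.state_dict()}
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CKPT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict['model'])
```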
EDIT: Thanks for pointing out the outdated documentation! I'll also update it to use the new upstream APIs.
Thanks, @jonb377, is it ok to assign this ticket to you now?